Spreading Depolarization in Brain Slices (Kelley et al. 2022)

 Download zip file 
Help downloading and running models
Accession:267259
A tissue-scale model of spreading depolarization (SD) in brain slices. We used the NEURON simulator's reaction-diffusion framework to implement embed thousands of neurons (based on the the model from Wei et al. 2014) in the extracellular space of a brain slice, which is itself embedded in an bath solution. We initiate SD in the slice by elevating extracellular K+ in a spherical region at the center of the slice. Effects of hypoxia and propionate on the slice were modeled by appropriate changes to the volume fraction and tortuosity of the extracellular space and oxygen/chloride concentrations.
Reference:
1 . Kelley C, Newton AJH, Hrabetova S, McDougal RA, Lytton WW (2022) Multiscale Computer Modeling of Spreading Depolarization in Brain Slices eNeuro [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Extracellular; Neuron or other electrically excitable cell; Glia;
Brain Region(s)/Organism:
Cell Type(s):
Channel(s): Na/K pump; NKCC1; KCC2; I Na, leak; I Cl, leak; I K,leak; I K; I Na,t;
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: NEURON;
Model Concept(s): Spreading depolarization; Spreading depression; Reaction-diffusion;
Implementer(s): Kelley, Craig; Newton, Adam J H [adam.newton at yale.edu]; Lytton, William [bill.lytton at downstate.edu]; McDougal, Robert [robert.mcdougal at yale.edu];
Search NeuronDB for information about:  I Na,t; I K; I K,leak; Na/K pump; I Cl, leak; I Na, leak; KCC2; NKCC1;
import numpy as np
from matplotlib import pyplot as plt 
import os
import imageio 
from scipy.signal import find_peaks 
import pickle
import multiprocessing

def totalSurfaceArea(Lx, Ly, Lz, density, alphaNrn, sa2v, rs=None):
    """Computes the total neuronal surface area given dimensions of the slice,
    neuronal density, neuronal volume fraction (alphaNrn), and neuronal S:V ratio"""
    Vtissue = Lx*Ly*Lz 
    # cell numbers 
    Ncell = int(density*(Lx*Ly*Lz*1e-9))
    if not rs:
        rs = ((alphaNrn*Vtissue)/(2*np.pi*Ncell)) ** (1/3)
    somaR = (sa2v * rs**3 / 2.0) ** (1/2)
    a = 4 * np.pi * somaR**2
    return a * Ncell 
    
def totalCellVolume(Lx, Ly, Lz, density, alphaNrn, sa2v, rs=None):
    """Computes the total neuronal volume given dimensions of the slice,
    neuronal density, neuronal volume fraction (alphaNrn), and neuronal S:V ratio"""
    Vtissue = Lx*Ly*Lz 
    # cell numbers 
    Ncell = int(density*(Lx*Ly*Lz*1e-9))
    if not rs:
        rs = ((alphaNrn*Vtissue)/(2*np.pi*Ncell)) ** (1/3)
    somaR = (sa2v * rs**3 / 2.0) ** (1/2)
    cyt_frac = rs**3 / somaR**3
    v = 2 * np.pi * somaR**3 * cyt_frac
    return v * Ncell

def combineRecs(dirs, recNum=None):
    """tool for combining recordings from multiple recordings.  dirs is a list of
    directories with data.  intended for use with state saving/restoring"""
    # open first file 
    if recNum:
        filename = 'recs' + str(recNum) + '.pkl'
    else:
        filename = 'recs.pkl'

    with open(dirs[0]+filename,'rb') as fileObj:
        data = pickle.load(fileObj)
    ## convert first file to all lists 
    for key in data.keys():
        if isinstance(data[key], list):
            for i in range(len(data[key])):
                try:
                    data[key][i] = list(data[key][i])
                except:
                    pass
        else:
            data[key] = list(data[key])
    # load each new data file
    for datadir in dirs[1:]:
        with open(datadir+filename, 'rb') as fileObj:
            new_data = pickle.load(fileObj)
        ## extend lists in original data with new data
        for key in new_data.keys():
            if isinstance(new_data[key], list):
                for i in range(len(new_data[key])):
                    try:
                        data[key][i].extend(list(new_data[key][i]))
                    except:
                        pass
            else:
                data[key].extend(list(new_data[key]))
    return data

def waveSpeedVs(dirs, xaxis, xlabel, ystop=None, figname='wavespeed'):
    # computes and plots wave speed vs. some specified parameter 
    for datadir in dirs:
        times = []
        wave_pos = []
        ## single run
        if not isinstance(datadir, list):
            f = open(datadir + 'wave_progress.txt', 'r')
            for line in f.readlines():
                times.append(float(line.split()[0]))
                wave_pos.append(float(line.split()[-2]))
            f.close()
        else:
            ## fragmented run 
            for d in datadir:
                f = open(d + 'wave_progress.txt', 'r')
                for line in f.readlines():
                    times.append(float(line.split()[0]))
                    wave_pos.append(float(line.split()[-2]))
                f.close()
        ## trim if needed 
        if ystop:
            times = [t for t, w in zip(times, wave_pos) if w < ystop]
            wave_pos = [w for w in wave_pos if w < ystop]
        

def compareDiffuse(dirs, outpath, species='k', figname='kconc', dur=10, start=0, vmin=3.5, vmax=40, extent=None):
    """generates gif(s) of K+ diffusion"""
    try:
        os.mkdir(outpath)
    except:
        pass
    
    files = [species+'_'+str(i)+'.npy' for i in range(int(start*1000),int(dur*1000)) if (i%100)==0]    
    for filename in files:
        plt.figure(figsize=(8,10))
        for ind in range(len(dirs)):
            plt.subplot(len(dirs), 1, ind+1)
            if isinstance(dirs[ind], list):
                try:
                    data = np.load(dirs[ind][0]+filename)
                except:
                    data = np.load(dirs[ind][1]+filename)
            else:
                data = np.load(dirs[ind]+filename)
            if extent:
                plt.imshow(data.mean(2), vmin=vmin, vmax=vmax, interpolation='nearest', origin='lower', extent=extent)     
            else:
                plt.imshow(data.mean(2), vmin=vmin, vmax=vmax, interpolation='nearest', origin='lower')     
            # plt.title(dirs[ind][:-1])
            t = str(float(filename.split('.')[0].split('_')[1]) / 1000) + ' s'
            plt.title(t, fontsize=18)
        
        plt.tight_layout()
        plt.savefig(outpath + filename[:-4] + '.png')
        plt.close()

    times = []
    filenames = os.listdir(path=outpath)
    for file in filenames: 
        times.append(float(file.split('_')[-1].split('.')[0])) 
    inds = np.argsort(times)
    filenames_sort = [filenames[i] for i in inds]
    imagesc = []
    for filename in filenames_sort: 
        imagesc.append(imageio.imread(outpath+filename)) 
    imageio.mimsave(figname + '.gif', imagesc)

def getKwaveSpeed(datadir, r0=0, rmax=None, tcut=None):
    """computes K+ wave speed, handles lists of data folders, drops points below r0"""
    if not isinstance(datadir, list):
        f = open(datadir + 'wave_progress.txt', 'r')
        times = []
        wave_pos = []
        for line in f.readlines():
            times.append(float(line.split()[0]))
            wave_pos.append(float(line.split()[-2]))
        f.close()
        if wave_pos:
            if tcut:
                wave_pos = [w for t, w in zip(times, wave_pos) if t <= tcut * 1000]
                times = [t for t in times if t <= tcut * 1000]
            print('Initial: %.3f' % (wave_pos[0]))
            print('Final: %.3f' % (wave_pos[-1]))
            if rmax:
                wave_pos_cut = [w for w in wave_pos if w > rmax]
                t_cut = [t for t, w in zip(times, wave_pos) if w > rmax]
                if len(wave_pos_cut):
                    m = (wave_pos_cut[0] - wave_pos[0]) / (t_cut[0] - times[0])
                elif np.max(wave_pos) > 200:
                    m = (np.max(wave_pos) - wave_pos[0]) / (times[np.argmax(wave_pos)]-times[0])
                else:
                    m = 0
            else:
                m = (np.max(wave_pos) - wave_pos[0]) / (times[np.argmax(wave_pos)]-times[0])
            m = m * 1000 / 16.667
        return m
    else:
        speeds = []
        for d in datadir:
            f = open(d + 'wave_progress.txt', 'r')
            times = []
            wave_pos = []
            for line in f.readlines():
                times.append(float(line.split()[0]))
                wave_pos.append(float(line.split()[-2]))
            f.close()
            if wave_pos:
                if tcut:
                    wave_pos = [w for t, w in zip(times, wave_pos) if t <= tcut * 1000]
                    times = [t for t in times if t <= tcut * 1000]
                print('Initial: %.3f' % (wave_pos[0]))
                print('Final: %.3f' % (wave_pos[-1]))
                if rmax:
                    wave_pos_cut = [w for w in wave_pos if w > rmax]
                    t_cut = [t for t, w in zip(times, wave_pos) if w > rmax]
                    if len(wave_pos_cut):
                        m = (wave_pos_cut[0] - wave_pos[0]) / (t_cut[0] - times[0])
                    elif np.max(wave_pos) > 200:
                        m = (np.max(wave_pos) - wave_pos[0]) / (times[np.argmax(wave_pos)]-times[0])
                    else:
                        m = 0
                else:
                    m = (np.max(wave_pos) - wave_pos[0]) / (times[np.argmax(wave_pos)]-times[0])
                m = m * 1000 / 16.667
            speeds.append(m)
        return speeds 

def getSpkWaveSpeed(datadir, pos, r0=0):
    """compute SD wave speed from spiking activity, handles lists, drops spikes within r0 """
    radius = []
    spktimes = []
    if not isinstance(datadir, list):
        m = getSpkMetrics(datadir, uniform=True, position=pos)
        r = [k for k in m.keys() if k > r0]
        t2firstSpk = [m[k]['t2firstSpk'] for k in m.keys() if k > r0]
        slope, b = np.polyfit(t2firstSpk, r, 1)
        return slope / 16.667
    else:
        speeds = []
        metrics = [getSpkMetrics(d, uniform=True, position=p) for d, p in zip(datadir,pos)]
        for d in datadir:
            for m in metrics:
                r = [k for k in m.keys() if k > r0]
                t2firstSpk = [m[k]['t2firstSpk'] for k in m.keys() if k > r0]
                slope, b = np.polyfit(t2firstSpk, r, 1)
                speeds.append(slope / 16.667)
        return speeds

def compareKwaves(dirs, labels, legendTitle, colors=None, trimDict=None, sbplt=None):
    """plots K+ wave trajectories from sims stored in list of folders dirs"""
    # plt.figure(figsize=(10,6))
    for d, l, c in zip(dirs, labels, colors):
        f = open(d + 'wave_progress.txt', 'r')
        times = []
        wave_pos = []
        for line in f.readlines():
            times.append(float(line.split()[0]))
            wave_pos.append(float(line.split()[-2]))
        f.close()
        if sbplt:
            plt.subplot(sbplt)
        if trimDict:
            if d in trimDict.keys():
                plt.plot(np.divide(times,1000)[:trimDict[d]], wave_pos[:trimDict[d]], label=l, linewidth=5, color=c)
            else:
                plt.plot(np.divide(times,1000), wave_pos, label=l, linewidth=5, color=c)
        else:
            if colors:
                plt.plot(np.divide(times,1000), wave_pos, label=l, linewidth=5, color=c)
            else:
                plt.plot(np.divide(times,1000), wave_pos, label=l, linewidth=5)
    # legend = plt.legend(title=legendTitle, fontsize=12)#, bbox_to_anchor=(-0.2, 1.05))
    legend = plt.legend(fontsize=12, loc='upper left')#, bbox_to_anchor=(-0.2, 1.05))
    # plt.setp(legend.get_title(), fontsize=14)
    plt.ylabel('K$^+$ Wave Position ($\mu$m)', fontsize=16)
    plt.xlabel('Time (s)', fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

def plotRasters(dirs, figname='raster_plot.png'):
    """Plot raster(s) ordered by radial distance from the core of the slice for data stored in list
    of folders - dirs"""
    raster_fig = plt.figure(figsize=(16,8))
    for ind, datadir in enumerate(dirs):
        raster = getRaster(datadir)
        plt.subplot(len(dirs), 1, ind+1) 
        for key in raster.keys():
            plt.plot(np.divide(raster[key],1000), [key for i in range(len(raster[key]))], 'b.')
        plt.xlabel('Time (s)', fontsize=16)
        plt.ylabel('Distance (microns)', fontsize=16)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
    raster_fig.savefig(os.path.join(figname))

def combineMemFiles(datadirs, file):
    """combine membrane potential files from fragmented runs"""
    ## load first file 
    with open(os.path.join(datadirs[0],file), 'rb') as fileObj:
        data = pickle.load(fileObj)
    ## convert voltages to lists
    for ind, v in enumerate(data[0]):
        data[0][ind] = list(v)
    ## load the rest of them 
    for datadir in datadirs[1:]:
        with open(os.path.join(datadir, file), 'rb') as fileObj:
            data0 = pickle.load(fileObj)
        for ind, v in enumerate(data0[0]):
            data[0][ind].extend(list(v))
        data[2].extend(data0[2])
    return data 

def getComboRaster(datadirs, position='center', uniform=False):
    """get raster in dictionary for fragmented simulation runs"""
    raster = {}
    for ind, datadir in enumerate(datadirs):
        if ind == 0:
            files = os.listdir(datadir)
            mem_files = [file for file in files if (file.startswith(position + 'membrane'))]
        for file in mem_files:
            with open(os.path.join(datadir,file), 'rb') as fileObj:
                    data = pickle.load(fileObj)
            if ind == 0:
                for v, pos in zip(data[0],data[1]):
                    pks, _ = find_peaks(v.as_numpy(), 0)
                    if len(pks):
                        if uniform:
                            r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
                        else:
                            r = (pos[0]**2 + pos[1]**2)**(0.5)
                        raster[r] = [data[2][ind] for ind in pks]
            else:
                for v, pos in zip(data[0],data[1]):
                    pks, _ = find_peaks(v.as_numpy(), 0)
                    if len(pks):
                        if uniform:
                            r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
                        else:
                            r = (pos[0]**2 + pos[1]**2)**(0.5)
                        if r in raster.keys():
                            raster[r].extend([data[2][ind] for ind in pks])
                        else:
                            raster[r] = [data[2][ind] for ind in pks]
    return raster 

def getSpkMetrics(datadir, includerecs=False, position='center', uniform=False):
    """returns basic spiking metrics from single sim"""
    raster = getRaster(datadir, includerecs=includerecs, position=position, uniform=uniform)
    spkMetrics = {}
    for k in raster.keys():
        spkMetrics[k] = {}
        t0 = raster[k][0]
        t1 = raster[k][-1]
        spkMetrics[k]['t2firstSpk'] = t0 / 1000
        spkMetrics[k]['nSpks'] = len(raster[k])
        if len(raster[k]) > 1:
            spkMetrics[k]['spkFreq'] = len(raster[k]) / ((t1-t0) / 1000)
            spkMetrics[k]['spkDur'] = (t1-t0) / 1000
        else:
            spkMetrics[k]['spkFreq'] = 1
            spkMetrics[k]['spkDur'] = 0
    return spkMetrics
        

def getRaster(datadir, includerecs=False, position='center', uniform=False):
    """Returns spike raster from a simulation as a dictionary (keys - radial distance)"""
    raster = {}
    files = os.listdir(datadir)
    if position:
        mem_files = [file for file in files if (file.startswith(position + 'membrane'))]
    else:
        mem_files = [file for file in files if (file.startswith('membrane'))]
    for file in mem_files:
        with open(os.path.join(datadir,file), 'rb') as fileObj:
            data = pickle.load(fileObj)
        for v, pos in zip(data[0],data[1]):
            pks, _ = find_peaks(v.as_numpy(), 0)
            if len(pks):
                if uniform:
                    r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
                else:
                    r = (pos[0]**2 + pos[1]**2)**(0.5)
                raster[r] = [data[2][ind] for ind in pks]
    ## also get spikes from recs.pkl 
    if includerecs:
        with open(datadir+'recs.pkl', 'rb') as fileObj:
            data = pickle.load(fileObj)
        for pos, v in zip(data['pos'], data['v']):
            pks, _ = find_peaks(v.as_numpy(), 0)
            if len(pks):
                r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
                raster[r] = [data['t'][ind] for ind in pks]
    return raster

def xyOfSpikeTime(datadir, position='center'):
    """returns dictionary where keys are spike times and values are position data for the spiking cell(s)"""
    if isinstance(datadir, list):
        files = os.listdir(datadir[0])
    else:
        files = os.listdir(datadir)
    mem_files = [file for file in files if (file.startswith(position + 'membrane'))]
    posBySpkTime = {}
    for file in mem_files:
        if isinstance(datadir, list):
            data = combineMemFiles(datadir, file)
        else:
            with open(os.path.join(datadir,file), 'rb') as fileObj:
                data = pickle.load(fileObj)
        for v, pos in zip(data[0],data[1]):
            if isinstance(v, list):
                pks, _ = find_peaks(v, 0)
            else:
                pks, _ = find_peaks(v.as_numpy(), 0)
            if len(pks):                
                for ind in pks:
                    t = int(data[2][ind])
                    if t in posBySpkTime.keys():
                        posBySpkTime[t]['x'].append(pos[0])
                        posBySpkTime[t]['y'].append(pos[1])
                        posBySpkTime[t]['z'].append(pos[2])
                    else:
                        posBySpkTime[t] = {}
                        posBySpkTime[t]['x'] = [pos[0]]
                        posBySpkTime[t]['y'] = [pos[1]]
                        posBySpkTime[t]['z'] = [pos[2]]
    return posBySpkTime

def centerVsPeriphKspeed(datadir, dur, rmax=600):
    """returns K+ wave speeds for both the core of the slice and the periphery"""
    k_files = ['k_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]
    time = []
    wave_pos_core = []
    wave_pos_periph = [] 
    for k_file in k_files:
        time.append(float(k_file.split('.')[0].split('_')[1]) / 1000)
        with open(datadir+k_file, 'rb') as fileObj:
            data = np.load(fileObj)
        core = int(data.shape[2]/2)
        periph = int(data.shape[2]-2)
        inds_core = np.argwhere(data[:,:,core] > 15.0)
        if len(inds_core):
            ds_core = []
            for ind in inds_core:
                ds_core.append(((ind[0]-20)**2 + (ind[1]-20)**2)**(1/2) * 500/20)
            wave_pos_core.append(np.max(ds_core))
        else:
            wave_pos_core.append(0)
        inds_periph = np.argwhere(data[:,:,periph] > 15.0)
        if len(inds_periph):
            ds_periph = []
            for ind in inds_periph:
                ds_periph.append(((ind[0]-20)**2 + (ind[1]-20)**2)**(1/2) * 500/20)
            wave_pos_periph.append(np.max(ds_periph))
        else:
            wave_pos_periph.append(0)
        wave_pos_cut = [w for w in wave_pos_core if w > rmax]
        t_cut = [t for t, w in zip(time, wave_pos_core) if w > rmax]
        if len(wave_pos_cut):
            m = (wave_pos_cut[0] - wave_pos_core[0]) / (t_cut[0] - time[0])
        elif np.max(wave_pos_core) > 200:
            m = (np.max(wave_pos_core) - wave_pos_core[0]) / (time[np.argmax(wave_pos_core)]-time[0])
        else:
            m = 0
        core_speed = m / 16.667
        wave_pos_cut = [w for w in wave_pos_periph if w > rmax]
        t_cut = [t for t, w in zip(time, wave_pos_periph) if w > rmax]
        if len(wave_pos_cut):
            m = (wave_pos_cut[0] - wave_pos_periph[0]) / (t_cut[0] - time[0])
        elif np.max(wave_pos_core) > 200:
            m = (np.max(wave_pos_periph) - wave_pos_periph[0]) / (time[np.argmax(wave_pos_periph)]-time[0])
        else:
            m = 0
        periph_speed = m / 16.667
    return core_speed, periph_speed



def allSpeciesMov(datadir, outpath, vmins, vmaxes, figname, condition='Perfused', dur=10, extent=None, fps=40):
    """"Generates an mp4 video with heatmaps for K+, Cl-, Na+, and O2 overlaid with spiking data"""
    try:
        os.mkdir(outpath)
    except:
        pass
    specs = ['k', 'cl', 'na', 'o2']
    k_files = [specs[0]+'_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]    
    cl_files = [specs[1]+'_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]    
    na_files = [specs[2]+'_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]    
    o2_files = [specs[3]+'_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]    
    for k_file, cl_file, na_file, o2_file in zip(k_files, cl_files, na_files, o2_files):
        t = int(k_file.split('.')[0].split('_')[1])
        ttl = 't = ' + str(float(k_file.split('.')[0].split('_')[1]) / 1000) + ' s'
        
        # fig, big_axes = plt.subplots( figsize=(14, 7.6) , nrows=1, ncols=1, sharey=True)
        fig = plt.figure(figsize=(12,7.6))
        # # for row, big_ax in enumerate(zip(big_axes,['Perfused']), start=1):
        # big_axes.set_title(ttl, fontsize=22)
        # # Turn off axis lines and ticks of the big subplot 
        # # obs alpha is 0 in RGBA string!
        # big_axes.tick_params(labelcolor=(1.,1.,1., 0.0), top='off', bottom='off', left='off', right='off')
        # # removes the white frame
        # # big_axes._frameon = False
        # plt.tight_layout()
        fig.text(.45, 0.825, ttl, fontsize=20)
        fig.text(0.45, 0.9, condition, fontsize=20)
        posBySpkTime = xyOfSpikeTime(datadir)
        spkTimes = [key for key in posBySpkTime if (t-50 < key <= t+50)]

        ax1 = fig.add_subplot(141)
        data = np.load(datadir+k_file)
        im = plt.imshow(data.mean(2), vmin=vmins[0], vmax=vmaxes[0], interpolation='nearest', origin='lower', extent=extent)
        plt.colorbar(im, fraction=0.046, pad=0.04)
        if len(spkTimes):
            for spkTime in spkTimes:
                plt.plot(posBySpkTime[spkTime]['x'], posBySpkTime[spkTime]['y'], 'w*', markerSize=5)
        plt.title(r'[K$^+$]$_{ECS}$ ', fontsize=20)

        ax2 = fig.add_subplot(142)
        data = np.load(datadir+cl_file)
        im = plt.imshow(data.mean(2), vmin=vmins[1], vmax=vmaxes[1], interpolation='nearest', origin='lower', extent=extent)
        plt.colorbar(im, fraction=0.046, pad=0.04)
        if len(spkTimes):
            for spkTime in spkTimes:
                plt.plot(posBySpkTime[spkTime]['x'], posBySpkTime[spkTime]['y'], 'w*', markerSize=5)
        plt.title(r'[Cl$^-$]$_{ECS}$ ', fontsize=20)

        ax3 = fig.add_subplot(143)
        data = np.load(datadir+na_file)
        im = plt.imshow(data.mean(2), vmin=vmins[2], vmax=vmaxes[2], interpolation='nearest', origin='lower', extent=extent)
        plt.colorbar(im, fraction=0.046, pad=0.04)
        if len(spkTimes):
            for spkTime in spkTimes:
                plt.plot(posBySpkTime[spkTime]['x'], posBySpkTime[spkTime]['y'], 'w*', markerSize=5)
        plt.title(r'[Na$^+$]$_{ECS}$ ', fontsize=20)

        ax4 = fig.add_subplot(144)
        data = np.load(datadir+o2_file)
        im = plt.imshow(data.mean(2), vmin=vmins[3], vmax=vmaxes[3], interpolation='nearest', origin='lower', extent=extent)
        plt.colorbar(im, fraction=0.046, pad=0.04)
        if len(spkTimes):
            for spkTime in spkTimes:
                plt.plot(posBySpkTime[spkTime]['x'], posBySpkTime[spkTime]['y'], 'w*', markerSize=5)
        plt.title(r'[O$_2$]$_{ECS}$ ', fontsize=20)
        plt.tight_layout()
        
        # plt.tight_layout()
        fig.savefig(outpath + k_file[:-4] + '.png')
        plt.close()

    times = []
    filenames = os.listdir(path=outpath)
    for file in filenames: 
        times.append(float(file.split('_')[-1].split('.')[0])) 
    inds = np.argsort(times)
    filenames_sort = [filenames[i] for i in inds]
    imagesc = []
    for filename in filenames_sort: 
        imagesc.append(imageio.imread(outpath+filename)) 
    imageio.mimsave(figname, imagesc)
        
        
def spkPlusMovie(dirs, outpath, species='k', vmin=3.5, vmax=40, figname='kmovie', dur=10, extent=None, titles=None):
    """generates gif(s) of K+ diffusion overlaid with spiking data"""
    try:
        os.mkdir(outpath)
    except:
        pass

    files = [species+'_'+str(i)+'.npy' for i in range(int(dur*1000)) if (i%100)==0]    
    for filename in files:
        plt.figure(figsize=(8,10))
        for ind in range(len(dirs)):
            # xy positions of spikes
            posBySpkTime = xyOfSpikeTime(dirs[ind])
            t = int(filename.split('.')[0].split('_')[1])
            plt.subplot(len(dirs), 1, ind+1)
            if isinstance(dirs[ind], list):
                for i in range(len(dirs[ind])):
                    try:
                        data = np.load(dirs[ind][i]+filename)
                    except:
                        pass
            else:
                data = np.load(dirs[ind]+filename)
            if extent:
                im = plt.imshow(data.mean(2), vmin=vmin, vmax=vmax, interpolation='nearest', origin='lower', extent=extent)          
            else:
                im = plt.imshow(data.mean(2), vmin=vmin, vmax=vmax, interpolation='nearest', origin='lower')
            cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
            cbar.set_label(r'[K$^+$] (mM)', fontsize=16, rotation=270)
            # plt.title(dirs[ind][:-1])
            spkTimes = [key for key in posBySpkTime if (t-50 < key <= t+50)]
            if len(spkTimes):
                for spkTime in spkTimes:
                    plt.plot(posBySpkTime[spkTime]['x'], posBySpkTime[spkTime]['y'], 'w*', markerSize=5)
            if titles:
                ttl = titles[ind] + ': ' + str(float(filename.split('.')[0].split('_')[1]) / 1000) + ' s'
            else:
                ttl = str(float(filename.split('.')[0].split('_')[1]) / 1000) + ' s'
            plt.xlabel(r'X-Position ($\mu$m)', fontsize=16)
            plt.ylabel(r'Y-Position ($\mu$m)', fontsize=16)
            plt.xticks(fontsize=14)
            plt.yticks(fontsize=14)
            plt.title(ttl, fontsize=20)
            plt.tight_layout()
            plt.savefig(outpath + filename[:-4] + '.png')
            plt.close()

    times = []
    filenames = os.listdir(path=outpath)
    for file in filenames: 
        times.append(float(file.split('_')[-1].split('.')[0])) 
    inds = np.argsort(times)
    filenames_sort = [filenames[i] for i in inds]
    imagesc = []
    for filename in filenames_sort: 
        imagesc.append(imageio.imread(outpath+filename)) 
    imageio.mimsave(figname + '.gif', imagesc)

def spikeFreqHeatMap(datadir, rmax=300, spatialBin=25, tempBin=25, noverlap=0, dur=10, sbplt=None, spkCmax=40):
    """Generates heatmaps where x is time, y is radial distance from center, color is average membrane potential
    within spatiotemporal bin. Overlaid with raster plot."""
    # rmax: maximum distance from center, spatialBin: in microns, tempBin: in ms, dur: duration in seconds 
    ## generate raster with distance keys
    if isinstance(datadir, list):
        raster = getComboRaster(datadir)
    else:
        raster = getRaster(datadir)

    ## allocate binned raster 
    # spatial_bins = list(range(spatialBin,rmax+1,spatialBin))
    spatial_bins = list(range(spatialBin,rmax+1,spatialBin-noverlap))
    binned_raster = {}
    for spatial_bin in spatial_bins:
        binned_raster[spatial_bin] = []

    ## find raster keys of each spatial bin 
    for key in raster.keys():
        ind = 0
        while key > spatial_bins[ind]:
            ind = ind + 1
        binned_raster[spatial_bins[ind]].append(raster[key])

    ## find avg spike rate for each spatial bin 
    temp_bins = list(range(tempBin, int(dur*1000)+1, tempBin))
    binned_binned = np.zeros((len(spatial_bins),len(temp_bins)))
    for row, key in enumerate(spatial_bins):
        for col, temp_bin in enumerate(temp_bins):
            spk_freqs = []
            for spk_times in binned_raster[key]:
                spk_times_inbin = [t for t in spk_times if (temp_bin-tempBin) < t < temp_bin] 
                spk_freqs.append(len(spk_times_inbin) / (tempBin / 1000))
            binned_binned[row][col] = np.mean(spk_freqs)
            if np.isnan(binned_binned[row][col]):
                binned_binned[row][col] = 0

    if sbplt:
        ## plotting 
        plt.subplot(sbplt)
        # ex = (0, dur/60, 0, rmax)
        ex = (0, dur, 0, rmax)
        hmap = plt.imshow(binned_binned, origin='lower', interpolation='spline16', 
                    extent=ex, cmap='inferno', aspect='auto', vmax=40)
        cbar = plt.colorbar(hmap)
        cbar.set_label(r'Avg. Spike Frequency (Hz)', fontsize=14)
        plt.clim(0,spkCmax)
        # plt.xlabel('Time (min)', fontsize=14)
        plt.xlabel('Time (s)', fontsize=14)
        plt.ylabel(r'Radial Distance ($\mu$m)', fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        # plt.xscale('log')
        return hmap, cbar
    else:
        return binned_binned

def avgVmembHeatMap(datadir, figname='vmemb_heat_map.png', rmax=300, spatialBin=25, tempBin=50, dur=10, position='center'):
    """Generates heatmaps where x is time, y is radial distance from center, color is average membrane potential."""
    # ignores spikes. rmax: maximum distance from center, spatialBin: in microns, tempBin: in ms, dur: duration in seconds 
    ## find membrane potential files 
    vmemb_by_pos = {}
    files = os.listdir(datadir)
    mem_files = [file for file in files if (file.startswith(position + 'membrane')) and (len(file.split('_')) == 3)]
    spatial_bins = list(range(spatialBin,rmax+1,spatialBin))

    ## generate dictionary where keys are position, values are binned Vmemb
    for file in mem_files:
        print(str(file))
        if isinstance(datadir, list):
            data = combineMemFiles(datadir, file)
        else:
            with open(os.path.join(datadir,file), 'rb') as fileObj:
                data = pickle.load(fileObj)
        t = data[2]
        for v, pos in zip(data[0],data[1]):
            v = v.as_numpy()[:-1]
            tbin = tempBin 
            r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
            vmemb_by_pos[r] = []
            while tbin <= dur * 1000:
                vbin = [V for V, T in zip(v,t) if tbin-tempBin < T < tbin]
                vmemb_by_pos[r].append(np.mean(vbin))
                tbin = tbin + tempBin

    ## spatially binned lists of mean vmembs 
    binned_vmembs = {}
    for space_bin in spatial_bins:
        binned_vmembs[space_bin] = []
        for r in vmemb_by_pos.keys():
            if space_bin-spatialBin < r < space_bin:
                binned_vmembs[space_bin].append(vmemb_by_pos[r])

    ## matrix with averaged vmembs for each spatial bin 
    temp_bins = list(range(tempBin, int(dur*1000)+1, tempBin))
    binned_binned = np.zeros((len(spatial_bins), len(temp_bins)))
    for row, key in enumerate(spatial_bins):
        binned_binned[row][:] = np.mean(np.array(binned_vmembs[key]), axis=0)

    ## plotting 
    # plt.ion()
    # plt.figure(figsize=(16,8))
    # plt.imshow(binned_binned, interpolation='nearest', origin='lower', aspect='auto', cmap='inferno')
    # plt.colorbar()
    # plt.savefig(figname)
    return binned_binned

def tempBinning(input_data):
    # temporal binning for Vmemb plots
    datadir, file, tempBin, dur, uniform, return_dict = input_data[0], input_data[1], input_data[2], input_data[3], input_data[4], input_data[5]
    print(str(file))
    if isinstance(datadir, list):
        data = combineMemFiles(datadir, file)
    else:
        with open(os.path.join(datadir,file), 'rb') as fileObj:
            data = pickle.load(fileObj)
    t = data[2]
    temp_dict = {}
    for v, pos in zip(data[0],data[1]):
        if isinstance(v, list):
            v = v[:-1]
        else:
            v = v.as_numpy()[:-1]
        tbin = tempBin 
        if uniform:
            r = (pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5)
        else:
            r = (pos[0]**2 + pos[1]**2)**(0.5)
        temp_dict[r] = []
        while tbin <= dur * 1000:
            vbin = [V for V, T in zip(v,t) if tbin-tempBin < T < tbin]
            temp_dict[r].append(np.mean(vbin))
            tbin = tbin + tempBin
    for key in temp_dict.keys():
        return_dict[key] = temp_dict[key]
    
def vmembHeatMapParallel(datadir, rmax=300, spatialBin=25, noverlap=10, tempBin=50, dur=10, poolSize=1, position='center', uniform=False):
    # Uses multiprocessing to run temporal binnging then gen heat map
    if isinstance(datadir, list):
        files = os.listdir(datadir[0])
    else:
        files = os.listdir(datadir)

    mem_files = [file for file in files if (file.startswith(position + 'membrane')) and (len(file.split('_')) == 3)]
    spatial_bins = list(range(spatialBin,rmax+1,spatialBin-noverlap))

    ## parallel temporal binning 
    manager = multiprocessing.Manager()
    vmemb_by_pos = manager.dict()
    data = []
    for mem_file in mem_files:
        data.append([datadir, mem_file, tempBin, dur, uniform, vmemb_by_pos])
    data = tuple(data)

    p = multiprocessing.Pool(poolSize)
    p.map(tempBinning, data)

    ## spatially binned lists of mean vmembs 
    binned_vmembs = {}
    for space_bin in spatial_bins:
        print(space_bin)
        binned_vmembs[space_bin] = []
        for r in vmemb_by_pos.keys():
            if space_bin-spatialBin < r < space_bin:
                binned_vmembs[space_bin].append(vmemb_by_pos[r])

    ## matrix with averaged vmembs for each spatial bin 
    temp_bins = list(range(tempBin, int(dur*1000)+1, tempBin))
    binned_binned = np.zeros((len(spatial_bins), len(temp_bins)))
    for row, key in enumerate(spatial_bins):
        print(row)
        binned_binned[row][:] = np.mean(np.array(binned_vmembs[key]), axis=0)

    return temp_bins, spatial_bins, binned_binned, vmemb_by_pos

def scatterRheobase(datadir, starts, amps, position='center'):
    # scatter plots for rheobase stims, currently unused
    rheo = []
    radii = []
    rheo_starts = []
    files = os.listdir(datadir)
    mem_files = [file for file in files if (file.startswith(position + 'membrane'))]
    t0 = starts[0]
    for file in mem_files:
        fileObj = open(os.path.join(datadir,file), 'rb')
        data = pickle.load(fileObj)
        fileObj.close()
        for v, pos in zip(data[0],data[1]):
            t = data[2]
            v = v.as_numpy()[:-1]
            v = [V for V, T in zip(v, t) if T >= t0-200]
            t = [T for T in t if T >= t0-200]
            pks, _ = find_peaks(v, 0)
            radii.append((pos[0]**2 + pos[1]**2 + pos[2]**2)**(0.5))
            if len(pks):
                spk_time = t[pks[0]]
                start_ind = np.argmin(np.square(np.subtract(starts,spk_time)))
                rheo.append(amps[start_ind])
                rheo_starts.append(starts[start_ind])
            else:
                rheo.append(0)
                rheo_starts.append(0)
    out = {'rheobase' : rheo, 'radius' : radii, 'rheo_startTime' : rheo_starts}
    return out
    
def comboVmembPlot(datadir, duration, fig, sbplt, vmin, vmax, spatialbin=40, noverlap=1, poolsize=1, rmax=300, position='center', uniform=False, top=False, left=False):
    # combine heatmap of Vmemb with white-out where spikes occur 
    ## heat map 
    # if isinstance(datadir, list):
    #     raster = getComboRaster(datadir)
    # else:
    raster = getRaster(datadir, position=position, uniform=uniform)
    temp_bins, spatial_bins, vmemb, vm_by_pos = vmembHeatMapParallel(datadir, 
        poolSize=poolsize, spatialBin=spatialbin, dur=duration, noverlap=noverlap, position=position,
        uniform=uniform)
    # if sbplt:
    #     plt.subplot(sbplt)
    # else:
    #     plt.figure(figsize=(8,10))
    # ex = (0, duration/60, 0, rmax)
    ex = (0, duration, 0, rmax)
    hmap = sbplt.imshow(vmemb, origin='lower', interpolation='spline16', 
                    extent=ex, cmap='inferno', aspect='auto', vmin=vmin, vmax=vmax)
    cbar = fig.colorbar(hmap, ax=sbplt)
    if not left:
        cbar.set_label(r'Avg. V$_{memb}$ (mV)', fontsize=14)
    cbar.ax.tick_params(labelsize=12)

    ## white-out spikes 
    for key in raster.keys():
        # plt.plot(np.divide(raster[key],60000), [key for i in range(len(raster[key]))], 'w.')
        sbplt.plot(np.divide(raster[key],1000), [key for i in range(len(raster[key]))], 'w.')
    # plt.xlabel('Time (min)', fontsize=14)
    if not top:
        sbplt.set_xlabel('Time (s)', fontsize=16)
    if left:
        sbplt.set_ylabel(r'Radial Distance ($\mu$m)', fontsize=16)
    plt.setp(sbplt.get_xticklabels(), fontsize=14)
    plt.setp(sbplt.get_yticklabels(), fontsize=14)

    return hmap, cbar 

def vmembSpkFreqSbPlts(datadir, duration, figname=None, spatialbin=40, tempBin=25, noverlap=1, poolsize=1, rmax=300, sbplts=[211,212], logscale=False, tend=6, spkCmax=40):
    # single figure with spike frequency heat map and vmemb heat map with spikes whited-out
    hmap1, cbar1 = spikeFreqHeatMap(datadir, 
                                    rmax=rmax, 
                                    spatialBin=spatialbin, 
                                    tempBin=tempBin, 
                                    dur=duration, 
                                    noverlap=noverlap,
                                    sbplt=sbplts[0],
                                    spkCmax=spkCmax)
    
    hmap2, cbar2 = comboVmembPlot(datadir, duration, 
                                    spatialbin=spatialbin, 
                                    noverlap=noverlap, 
                                    poolsize=poolsize, 
                                    rmax=rmax, 
                                    sbplt=sbplts[1])

    if logscale:
        plt.subplot(sbplts[0])
        plt.xscale('log')
        plt.subplot(sbplts[1])
        plt.xscale('log')

    plt.subplot(sbplts[0])
    plt.xlim(0.01, tend)
    plt.subplot(sbplts[1])
    plt.xlim(0.01, tend)

    if figname:
        plt.tight_layout()
        plt.savefig(figname)
    else:
        return hmap1, hmap2, cbar1, cbar2

def compositeFig(datadir, figname, dif_files, duration=20, spatialBin=75, tempBin=50, noverlap=10, rmax=1250, poolsize=40, sbplts=[223, 224], useAgg=True, logscale=False, extent=(-1000,1000,-1000,1000), figsize=(16,10), tend=6):
    if useAgg:
            import matplotlib; matplotlib.use('Agg')

    plt.figure(figsize=figsize)

    for filename, i in zip(dif_files, range(4)):
        data = np.load(datadir+filename)
        plt.subplot(2,4,i+1)
        hmap = plt.imshow(data.mean(2), vmin=3.5, vmax=40, interpolation='nearest', origin='lower', extent=extent, cmap='inferno')
        t = str(float(filename.split('.')[0].split('_')[1]) / 1000) + ' s'
        plt.title(t, fontsize=18)
        plt.xlabel(r'X Position ($\mu$m)', fontsize=14)
        plt.xticks(fontsize=12)
        plt.ylabel(r'Y Position ($\mu$m)', fontsize=14)
        plt.yticks(fontsize=12)
    # cbar = plt.colorbar(hmap)
    # cbar.set_label(r'[K$^+$] (mM)', fontsize=14)

    vmembSpkFreqSbPlts(datadir, duration, figname=figname, tend=tend, spatialbin=spatialBin, tempBin=tempBin, noverlap=noverlap, poolsize=poolsize, rmax=rmax, sbplts=sbplts, logscale=logscale)

def sinkVsSource(datadir, recNum=None):
    """epoching based on when cells act as K+ sinks"""
    if recNum:
        filename = 'recs' + str(recNum) + '.pkl'
    else:
        filename = 'recs.pkl'

    if isinstance(datadir, list):
        data = combineRecs(datadir)
    else:
        with open(datadir+filename, 'rb') as fileObj:
            data = pickle.load(fileObj)
    sink_periods = []
    radius = []
    for i, k in enumerate(data['ki']):
        rising_t = [t for t,d in zip(list(data['t'])[1:], 
                        np.diff(list(k))) if d > 0]
        # sink_periods.append(rising_t)
        if len(rising_t) > 1:
            sink_periods.append(rising_t[-1] - rising_t[0])
        else:
            sink_periods.append(0)
        radius.append((data['pos'][i][0] ** 2 + data['pos'][i][1] ** 2 + data['pos'][i][2] ** 2)**(0.5))
    return sink_periods, radius

def sinkVsSourceFig(datadir, iss=[0, 7, 15], recNum=None):
    """basic plotting for epochs where cells act as K+ sinks"""
    if recNum:
        filename = 'recs' + str(recNum) + '.pkl'
    else:
        filename = 'recs.pkl'

    if isinstance(datadir, list):
        data = combineRecs(datadir)
    else:
        with open(datadir+filename, 'rb') as fileObj:
            data = pickle.load(fileObj)

    fig, axs = plt.subplots(6,1, sharex=True)
    fig.set_figwidth(10)
    fig.set_figheight(10)

    for ind, i in enumerate(iss):
        rising_t = [t for t,d in zip(list(data['t'])[1:], 
                        np.diff(list(data['ki'][i]))) if d > 0]
        max_rise_t = np.max(rising_t)
        predrop_t = [t for t in list(data['t']) if t <= max_rise_t]
        predrop_ki = [ki for t, ki in zip(list(data['t']), list(data['ki'][i]))
                        if t <= max_rise_t]
        predrop_o2 = [o2 for t, o2 in zip(list(data['t']), list(data['o2'][i]))
                        if t <= max_rise_t]
        predrop_nai = [nai for t, nai in zip(list(data['t']), list(data['nai'][i]))
                        if t <= max_rise_t]
        predrop_cli = [cli for t, cli in zip(list(data['t']), list(data['cli'][i]))
                        if t <= max_rise_t]
        predrop_v = [v for t, v in zip(list(data['t']), list(data['v'][i]))
                        if t <= max_rise_t]
        predrop_ko = [ko for t, ko in zip(list(data['t']), list(data['ko'][i]))
                        if t <= max_rise_t]
        l = r'%s $\mu$m' % str(np.round((data['pos'][i][0] ** 2 + data['pos'][i][1] ** 2 + data['pos'][i][2] ** 2)**(0.5),1))
        axs[0].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_v, label=l)
        axs[1].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_ko, label=l)
        axs[2].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_ki, label=l)
        axs[3].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_o2, label=l)
        axs[4].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_nai, label=l)
        axs[5].plot(np.divide(predrop_t, np.max(predrop_t)), predrop_cli, label=l)
    
    leg = axs[0].legend(title='Radial Position', fontsize=8, bbox_to_anchor=(-0.275, 1.05))
    plt.setp(leg.get_title(), fontsize=10)
    # axs[0].set_title(r'Membrane Potential', fontsize=18)
    axs[0].set_ylabel(r'V$_{memb}$ (mV)', fontsize=12)
    plt.setp(axs[0].get_xticklabels(), fontsize=10)
    plt.setp(axs[0].get_yticklabels(), fontsize=10)
    # axs[1].set_title(r'[K$^+$]$_o$', fontsize=18)
    axs[1].set_ylabel(r'[K$^+$]$_o$ (mM)', fontsize=12)
    plt.setp(axs[1].get_xticklabels(), fontsize=10)
    plt.setp(axs[1].get_yticklabels(), fontsize=10)
    # axs[2].set_title(r'[K$^+$]$_i$', fontsize=18)
    axs[2].set_ylabel(r'[K$^+$]$_i$ (mM)', fontsize=12)
    plt.setp(axs[2].get_xticklabels(), fontsize=10)
    plt.setp(axs[2].get_yticklabels(), fontsize=10)
    # axs[3].set_title(r'[O$_2$]$_{ECS}$', fontsize=18)
    axs[3].set_ylabel(r'[O$_2$]$_{ECS}$ (mM)', fontsize=12)
    plt.setp(axs[3].get_xticklabels(), fontsize=10)
    plt.setp(axs[3].get_yticklabels(), fontsize=10)
    # axs[4].set_title(r'[Na$^+$]$_i$', fontsize=18)
    axs[4].set_ylabel(r'[Na$^+$]$_i$ (mM)', fontsize=12)
    plt.setp(axs[4].get_xticklabels(), fontsize=10)
    plt.setp(axs[4].get_yticklabels(), fontsize=10)
    # axs[5].set_title(r'[Cl$^-$]$_i$', fontsize=18)
    axs[5].set_ylabel(r'[Cl$^-$]$_i$ (mM)', fontsize=12)
    axs[5].set_xlabel(r'Normalized Time', fontsize=12)
    plt.setp(axs[5].get_xticklabels(), fontsize=10)
    plt.setp(axs[5].get_yticklabels(), fontsize=10)

def traceExamples(datadir, figname, iss=[0, 7, 15], recNum=None):
    """Function for plotting Vmemb, as well as ion and o2 concentration, for selected (iss) recorded
    neurons"""
    if recNum:
        filename = 'recs' + str(recNum) + '.pkl'
    else:
        filename = 'recs.pkl'

    if isinstance(datadir, list):
        data = combineRecs(datadir)
    else:
        with open(datadir+filename, 'rb') as fileObj:
            data = pickle.load(fileObj)
    # fig = plt.figure(figsize=(18,9))
    fig, axs = plt.subplots(2,4)
    fig.set_figheight(9)
    fig.set_figwidth(18)
    for i in iss:
        l = r'%s $\mu$m' % str(np.round((data['pos'][i][0] ** 2 + data['pos'][i][1] ** 2 + data['pos'][i][2] ** 2)**(0.5),1))
        axs[0][0].plot(np.divide(data['t'],1000), data['v'][i], label=l)
        axs[1][0].plot(np.divide(data['t'],1000), data['o2'][i])
        axs[0][1].plot(np.divide(data['t'],1000), data['ki'][i])
        axs[1][1].plot(np.divide(data['t'],1000), data['ko'][i])
        axs[0][2].plot(np.divide(data['t'],1000), data['nai'][i])
        axs[1][2].plot(np.divide(data['t'],1000), data['nao'][i])
        axs[0][3].plot(np.divide(data['t'],1000), data['cli'][i])
        axs[1][3].plot(np.divide(data['t'],1000), data['clo'][i])
    
    leg = axs[0][0].legend(title='Radial Position', fontsize=11, bbox_to_anchor=(-0.275, 1.05))
    plt.setp(leg.get_title(), fontsize=15)
    axs[0][0].set_ylabel('Membrane Potential (mV)', fontsize=16)
    plt.setp(axs[0][0].get_xticklabels(), fontsize=14)
    plt.setp(axs[0][0].get_yticklabels(), fontsize=14)
    axs[0][0].text(-0.15, 1.0, 'A)', transform=axs[0][0].transAxes,
        fontsize=16, fontweight='bold', va='top', ha='right')

    axs[1][0].set_ylabel(r'Extracellular [O$_{2}$] (mM)', fontsize=16)
    plt.setp(axs[1][0].get_xticklabels(), fontsize=14)
    plt.setp(axs[1][0].get_yticklabels(), fontsize=14)
    axs[1][0].text(-0.15, 1., 'E)', transform=axs[1][0].transAxes,
        fontsize=16, fontweight='bold', va='top', ha='right')
    
    axs[0][1].set_ylabel(r'Intracellular [K$^{+}$] (mM)', fontsize=16)
    plt.setp(axs[0][1].get_xticklabels(), fontsize=14)
    plt.setp(axs[0][1].get_yticklabels(), fontsize=14)
    axs[0][1].text(-0.15, 1., 'B)', transform=axs[0][1].transAxes,
        fontsize=18, fontweight='bold', va='top', ha='right')
    
    axs[1][1].set_ylabel(r'Extracellular [K$^{+}$] (mM)', fontsize=16)
    plt.setp(axs[1][1].get_xticklabels(), fontsize=14)
    plt.setp(axs[1][1].get_yticklabels(), fontsize=14)
    axs[1][1].text(-0.15, 1.0, 'F)', transform=axs[1][1].transAxes,
      fontsize=18, fontweight='bold', va='top', ha='right')
    
    axs[0][2].set_ylabel(r'Intracellular [Na$^{+}$] (mM)', fontsize=16)
    plt.setp(axs[0][2].get_xticklabels(), fontsize=14)
    plt.setp(axs[0][2].get_yticklabels(), fontsize=14)
    axs[0][2].text(-0.15, 1.0, 'C)', transform=axs[0][2].transAxes,
      fontsize=18, fontweight='bold', va='top', ha='right')

    axs[1][2].set_ylabel(r'Extracellular [Na$^{+}$] (mM)', fontsize=16)
    plt.setp(axs[1][2].get_xticklabels(), fontsize=14)
    plt.setp(axs[1][2].get_yticklabels(), fontsize=14)
    axs[1][2].text(-0.15, 1.0, 'G)', transform=axs[1][2].transAxes,
      fontsize=18, fontweight='bold', va='top', ha='right')
    
    axs[0][3].set_ylabel(r'Intracellular [Cl$^{-}$] (mM)', fontsize=16)
    plt.setp(axs[0][3].get_xticklabels(), fontsize=14)
    plt.setp(axs[0][3].get_yticklabels(), fontsize=14)
    axs[0][3].text(-0.15, 1.0, 'D)', transform=axs[0][3].transAxes,
      fontsize=18, fontweight='bold', va='top', ha='right')
    
    axs[1][3].set_ylabel(r'Extracellular [Cl$^{-}$] (mM)', fontsize=16)
    plt.setp(axs[1][3].get_xticklabels(), fontsize=14)
    plt.setp(axs[1][3].get_yticklabels(), fontsize=14)
    axs[1][3].text(-0.15, 1.0, 'H)', transform=axs[1][3].transAxes,
      fontsize=18, fontweight='bold', va='top', ha='right')

    fig.text(0.55, 0.01, 'Time (s)', fontsize=16)
    plt.tight_layout()
    plt.savefig(figname)
    # return fig, axs

# v0.00 - tools for plotting simulation data, compare K diffusion in two conds
# v0.01 - generalized compareKdiffuse for arbitrary list of data folders
# v0.02 - added functions for plotting raster sorted by radial distance from center and heat map of spike frequency
# v0.03 - added function for plotting Vmemb heat maps
# v0.04 - parallelized Vmemb heat map analysis
# v0.05 - added function for computing rheobase from stimulated cells
# v0.06 - added raster function sans plotting 
# v0.07 - added plotting K+ wave position for multiple sims, working on combo Vmemb and raster
# v0.08 - combined heatmaps for vmemb and spike freq for single figure  
# v0.09 - add times to gifs
# v0.10 - fig with combo of k diffusion and 
# v0.11 - single function for combo figure(s)
# v0.12 - added function for plotting example traces from single run
# v0.13 - combining recordings from different/continued runs, testing with example traces func
# v0.13.1 - combining fragmented runs for rasters and membrane potential files 
# v0.14 - getRaster() combines recks.pkl and membrane_potential.pkl files 
# v0.15 - compute wave speeds for spike wave and K+ wave 
# v0.16 - enable subplots for compareKwaves()
# v0.16.1 - enable fragmented sims for species diffusion gifs
# v0.17 - replace nans with 0 in spike freq heatmap
# v0.17.1 - specification of max for spk freq heatmap and switch to perceptually uniform cmap
# v0.17.2 - ditch linear fit for wave position / speed computation
# v0.17.3 - user can specify whether to plot cells in center of periphery of slice
# v0.17.4 - user specifies center or periphery for combo vmemb and spike heatmap
# v0.17.5 - raster needs to differentiate between uniformly distributed recs and single layer
# v0.18 - introduced function for computing spike metrics (e.g. frequency, duration of spiking)
# v0.19 - changing all species gif to mp4
# v0.20 - sinkVsSource for looking at periods where cells act as K+ sink 
# v0.21 - new method for K+ wave speed in core vs periphery of slice
# v1.0 - minor updates to raster ploting and mp4 functions 
# v1.1 - added doc strings for most functions

Loading data, please wait...