Spike-Timing-Based Computation in Sound Localization (Goodman and Brette 2010)

Accession: 126465
" ... In neuron models consisting of spectro-temporal filtering and spiking nonlinearity, we found that the binaural structure induced by spatialized sounds is mapped to synchrony patterns that depend on source location rather than on source signal. Location-specific synchrony patterns would then result in the activation of location-specific assemblies of postsynaptic neurons. We designed a spiking neuron model which exploited this principle to locate a variety of sound sources in a virtual acoustic environment using measured human head-related transfer functions. ..."
Reference:
1. Goodman DF, Brette R (2010) Spike-timing-based computation in sound localization. PLoS Comput Biol 6:e1000993 [PubMed]
Model Information:
Model Type: Realistic Network;
Simulation Environment: Brian; Python;
Model Concept(s): Coincidence Detection; Synchronization;
Implementer(s): Goodman, Dan F. M.
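
The principle described in the abstract (interaural delays and gains map source location onto synchrony patterns, which coincidence detectors can read out) can be illustrated with a minimal, self-contained NumPy sketch; all values here are illustrative and not taken from the model:

    import numpy as np
    # A spike train at the left ear; the right ear receives the same train
    # shifted by the source's interaural time difference (ITD).
    rng = np.random.RandomState(0)
    fs = 44100.0
    left = np.sort(rng.uniform(0, 1.0, 200))   # spike times in seconds
    true_itd = 300e-6                          # 300 microseconds (illustrative)
    right = left + true_itd
    # A bank of coincidence detectors, one per candidate delay: the detector
    # whose compensating delay matches the ITD sees synchronous input and
    # registers the most coincidences, so an argmax estimates the ITD.
    candidate_delays = np.arange(-20, 21)/fs
    tolerance = 0.5/fs                         # coincidence window
    counts = [np.sum(np.min(np.abs((left + d)[:, None] - right[None, :]),
                            axis=1) < tolerance)
              for d in candidate_delays]
    estimated_itd = candidate_delays[np.argmax(counts)]

The model file below applies the same idea per cochlear frequency channel, over a grid of gains as well as delays, using spiking neuron models rather than an explicit coincidence count. The code is written for Brian 1 (Python 2), matching the simulation environment listed above.
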
from shared import *
from hrtf_analysis import *
from models import *
import gc

class AllPairsModel(object):
    '''
    Initialise this object with an HRTFSet, a cochlear frequency range
    (cfmin, cfmax, cfN), a range of gains (gain_max in dB, gain_N) and a
    range of delays (delay_max, delay_N), and optionally:
    a model for the coincidence detector neurons (cd_model), and
    a model for the cochlear filter neurons (filtergroup_model).

    The __call__ method returns an array of spike counts (see the docstring
    of that method).
    '''
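    # Illustrative construction (placeholder values; cf. the __main__ block
    # at the bottom of this file):
    #     apm = AllPairsModel(hrtfset, 150*Hz, 5*kHz, 80,
    #                         gain_max=8.0, gain_N=61,
    #                         delay_max=0.8*ms, delay_N=35)
    #     count = apm(sound, index)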
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 gain_max, gain_N, delay_max, delay_N,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 ):
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        self.gain_max = gain_max
        self.gain_N = gain_N
        self.delay_max = delay_max
        self.delay_N = delay_N
        
        self.num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
                
        # dummy sound; it is replaced with the real sound in __call__
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)

        # prepare the bank of gains: gain_N gains spaced symmetrically about
        # 0 dB, running from -gain_max to +gain_max (gain_N should be odd)
        m = (gain_N+1)//2
        gains_dB = linspace(0, gain_max, m)
        gains = 10**(gains_dB/20)
        gains = hstack((1/gains[::-1], gains[1:]))
        allgains = reshape(gains, (1, 1, gains.size))

        def apply_gains(y):
            # y holds cfN left-ear channels followed by cfN right-ear
            # channels; each channel is expanded into gain_N copies, with the
            # gain order mirrored for the right ear so that each output
            # column applies opposite gains to the two ears
            nsamples = y.shape[0]
            cfN = y.shape[1]//2
            y = reshape(y, (nsamples, 2*cfN, 1))
            y1 = y[:, :cfN, :]*allgains
            y2 = y[:, cfN:, :]*allgains[:, :, ::-1]
            y = hstack((y1, y2))
            y = reshape(y, (nsamples, y.size//nsamples))
            return y
        
        gfb = Gammatone(Repeat(soundinput, cfN), hstack((cf, cf)))
                
        gains_fb = FunctionFilterbank(gfb, apply_gains)
        # apply_gains changes the number of channels, so set it explicitly
        gains_fb.nchannels = gfb.nchannels*gain_N
        
        # half-wave rectification followed by the model's static compression
        compress = filtergroup_model['compress']
        cochlea = FunctionFilterbank(gains_fb, lambda x: compress(clip(x, 0, Inf)))
        
        # Create the filterbank group
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])
        
        # create the synchrony group
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(cfN*gain_N*(delay_N*2-1), cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)
        
        # set up the synaptic connectivity
        left_delays = hstack((zeros(delay_N-1), linspace(0, float(delay_max), delay_N)))
        right_delays = left_delays[::-1]
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var', delay=True, max_delay=delay_max)
        for i, j, dl, dr in zip(repeat(arange(cfN*gain_N), 2*delay_N-1),
                                arange(cfN*gain_N*(delay_N*2-1)),
                                tile(left_delays, cfN*gain_N),
                                tile(right_delays, cfN*gain_N)):
            C[i, j] = cd_weight
            C[i+cfN*gain_N, j] = cd_weight
            C.delay[i, j] = dl
            C.delay[i+cfN*gain_N, j] = dr
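
        # Note on the delay geometry: each (cf, gain) channel pair feeds
        # 2*delay_N-1 coincidence detectors; left_delays ramps up while
        # right_delays ramps down, so detector k compensates an interaural
        # delay of left_delays[k] - right_delays[k], tiling the range
        # -delay_max..+delay_max with zero delay in the middle.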

        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)
        
    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply the all-pairs filtering group to the given sound. The sound
        should be stereo unless you specify the HRTF index (or the
        coordinates of one as keyword arguments), in which case it should be
        mono and the chosen HRTF will be applied to it. You can also pass an
        HRTF object directly as index. Returns the spike counts of the
        neurons in the synchrony group, with shape (cfN, gain_N, delay_N*2-1).
        '''
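        # Illustrative call patterns (coordinate keyword names depend on the
        # HRTF database; the IRCAM LISTEN set used below takes azim and elev):
        #     count = model(stereo_sound)           # already spatialized
        #     count = model(mono_sound, index=5)    # apply HRTF number 5
        #     count = model(mono_sound, azim=30, elev=0)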
        hrtf = None
        # Check for an HRTF object first: the original ordering made the
        # isinstance branch unreachable, since any HRTF is also not None.
        if isinstance(index, HRTF):
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif len(indexkwds):
            hrtf = self.hrtfset(**indexkwds)
        if hrtf is not None:
            sound = hrtf(sound)
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        count = reshape(self.counter.count,
                        (self.cfN, self.gain_N, self.delay_N*2-1))
        return count
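
# Illustration (not part of the original file): the count array returned above
# has shape (cfN, gain_N, 2*delay_N-1), so a crude per-channel readout of the
# best-matching interaural delay and gain is an argmax over the other axes:
#
#     best_delay = argmax(amax(count, axis=1), axis=1)  # delay index per channel
#     best_gain = argmax(amax(count, axis=2), axis=1)   # gain index per channel
#
# The plotting code below marks exactly these argmax positions with 'x'.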

if __name__=='__main__':
    
    from plot_count import ircam_plot_count

    hrtfdb = get_ircam()
    subject = 1002
    hrtfset = hrtfdb.load_subject(subject)
    index = randint(hrtfset.num_indices)
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    gain_max, gain_N = 8.0, 61
    delay_N = 35
    delay_max = delay_N/samplerate
    # Change this to 10*second to reproduce the corresponding figure in the paper
    sound = whitenoise(200*ms).atlevel(80*dB)
    
    apmodel = AllPairsModel(hrtfset, cfmin, cfmax, cfN,
                            gain_max, gain_N, delay_max, delay_N)
    
    count = apmodel(sound, index)

    # Complicated code to plot the output nicely
    freqlabels = array([150*Hz, 1*kHz, 2*kHz, 3*kHz, 4*kHz, 5*kHz])
    fig_mew = 1 # marker edge width (in points)
    num_indices = hrtfset.num_indices
    from scipy.ndimage import gaussian_filter
    itd, ild = hrtfset_itd_ild(hrtfset, cfmin, cfmax, cfN)
    delays = array([itd[index][i] for i in range(cfN)])
    gains = array([ild[index][i] for i in range(cfN)])
    gains = 20*log10(gains)
    delays = -array(delays*samplerate, dtype=int)+delay_N-1
    arrgains = linspace(-gain_max, gain_max, gain_N)
    gains = digitize(gains, 0.5*(arrgains[1:]+arrgains[:-1]))
    gains = gain_N-1-gains
    def dofig(count, blur=0, blurmode='reflect', freqlabels=None):
        count = array(count, dtype=float) 
        ocount = count
        count = copy(ocount)
        count.shape = (cfN, gain_N, delay_N*2-1)
        count = amax(count, axis=1)
        count.shape = (cfN, delay_N*2-1)
        subplot(121)
        count = gaussian_filter(count, blur, mode=blurmode)
        imshow(count, origin='lower', interpolation='nearest', aspect='auto',
               extent=(-float(delay_N/samplerate/msecond), float(delay_N/samplerate/msecond), 0, cfN))
        plot((delays-delay_N)/samplerate/msecond, arange(cfN), '+', color=(0,0,0), mew=fig_mew)
        plot((argmax(count, axis=1)-delay_N)/samplerate/msecond, arange(cfN), 'x', color=(1,1,1), mew=fig_mew)
        axis((float(-delay_N/samplerate/msecond), float(delay_N/samplerate/msecond), 0, cfN))
        xlabel('Delay (ms)')
        if freqlabels is None:
            yticks([])
            ylabel('Channel')
        else:
            cf = erbspace(cfmin, cfmax, cfN)
            j = digitize(freqlabels, .5*(cf[1:]+cf[:-1]))
            yticks(j, map(str, array(freqlabels, dtype=int)))
            ylabel('Channel (Hz)')
        subplot(122)
        count = copy(ocount)
        count.shape = (cfN, gain_N, delay_N*2-1)
        count = amax(count, axis=2)
        count.shape = (cfN, gain_N)
        count = gaussian_filter(count, blur, mode=blurmode)
        imshow(count, origin='lower', interpolation='nearest', aspect='auto')
        plot(gains, arange(cfN), '+', color=(0,0,0), mew=fig_mew)
        plot(argmax(count, axis=1), arange(cfN), 'x', color=(1,1,1), mew=fig_mew)
        axis('tight')
        xlabel('Relative gain (dB)')
        xticks([0, (gain_N-1)/2, gain_N-1], [str(min(arrgains)), '0', str(max(arrgains))])
        if freqlabels is None:
            yticks([])
            ylabel('Channel')
        else:
            cf = erbspace(cfmin, cfmax, cfN)
            j = digitize(freqlabels, .5*(cf[1:]+cf[:-1]))
            yticks(j, map(str, array(freqlabels, dtype=int)))
            ylabel('Channel (Hz)')
    dofig(count, freqlabels=freqlabels)
    figure()
    dofig(count, blur=1)#, freqlabels=[500, 1000, 2000, 3000, 4000, 5000])
    figure()
    dofig(count, blur=2)
    show()