Spike-Timing-Based Computation in Sound Localization (Goodman and Brette 2010)

Accession: 126465
" ... In neuron models consisting of spectro-temporal filtering and spiking nonlinearity, we found that the binaural structure induced by spatialized sounds is mapped to synchrony patterns that depend on source location rather than on source signal. Location-specific synchrony patterns would then result in the activation of location-specific assemblies of postsynaptic neurons. We designed a spiking neuron model which exploited this principle to locate a variety of sound sources in a virtual acoustic environment using measured human head-related transfer functions. ..."
Reference:
Goodman DF, Brette R (2010) Spike-timing-based computation in sound localization. PLoS Comput Biol 6:e1000993 [PubMed]
Model Information:
Model Type: Realistic Network;
Simulation Environment: Brian; Python;
Model Concept(s): Coincidence Detection; Synchronization;
Implementer(s): Goodman, Dan F. M. ;
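
The model code below implements this principle in Brian. As a rough standalone intuition (not part of the model files; all names and values here are illustrative, and plain correlation stands in for spiking coincidence detection): after compensating the interaural delay imposed by a toy "location", the two ear signals become synchronous, so the best-matching internal delay identifies the location regardless of the source signal.

import numpy as np

rng = np.random.default_rng(0)
itd_true = 12                         # interaural delay in samples (the toy "location")
signal = rng.standard_normal(4410)    # arbitrary source; the estimate should not depend on it
left = signal
right = np.roll(signal, itd_true)     # the right ear hears the source later

# One "coincidence detector" per candidate internal delay: each scores how
# synchronous the two ears are once its delay has been compensated.
candidate_delays = np.arange(-20, 21)
scores = [np.dot(left, np.roll(right, -d)) for d in candidate_delays]
print('estimated ITD:', candidate_delays[int(np.argmax(scores))], 'samples')

The full model below replaces this correlation with gammatone filtering, measured location-specific delays and gains, and spiking coincidence detector neurons.
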
from shared import *
from hrtf_analysis import *
from models import *
import gc

class ApproximateFilteringModel(object):
    '''
    Initialise this object with an HRTF set (hrtfset) and a cochlear
    frequency range (cfmin, cfmax, cfN), and optionally:

    * a model for the coincidence detector neurons (cd_model),
    * a model for the filter neurons (filtergroup_model),
    * whether or not to use the correlation-based best delays (use_delays),
    * whether or not to use the best gains (use_gains),
    * whether or not to use only the phase information, i.e. delays wrapped
      so that their phase lies in [-pi, pi] (use_only_phase),
    * an alternative set of ITD/ILD pairs (itdild; see the function
      hrtfset_itd_ild in hrtf_analysis.py for more information).

    The __call__ method returns a spike count (see the docstring of that
    method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 use_delays=True, use_gains=True, use_only_phase=False,
                 itdild=None,
                 ):
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        
        self.num_indices = num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
        
        # extract ITDs/ILDs in the right form
        if itdild is None:
            all_itds, all_ilds = hrtfset_itd_ild(hrtfset, cfmin, cfmax, cfN)
        else:
            all_itds, all_ilds = itdild
        d = array([all_itds[j][i] for i in xrange(cfN) for j in xrange(num_indices)])
        g = array([all_ilds[j][i] for i in xrange(cfN) for j in xrange(num_indices)])
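        # Build a gain pair (1/g, g) per (CF, location) channel, then shift
        # both by -|gain in dB| so that the louder ear ends up at 0 dB: the
        # interaural level difference is preserved but applied purely as an
        # attenuation of the quieter ear.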
        gains = hstack((1/g, g))
        gains_dB = 20*log10(gains)
        abs_gains_dB = abs(gains_dB)
        r = -abs_gains_dB[:len(gains)/2]
        r = hstack((r, r))
        gains_dB += r
        gains = 10**(gains_dB/20)
        gains = reshape(gains, (1, len(gains)))

        if use_only_phase:
            d_cf = repeat(cf, num_indices)
            d = imag(log(exp(1j*2*pi*d*d_cf)))/(2*pi*d_cf) # wrap each delay so its phase at the channel's centre frequency lies in [-pi, pi]

        delays_L = where(d>=0, zeros(len(d)), -d)
        delays_R = where(d>=0, d, zeros(len(d)))
        delay_max = max(amax(delays_L), amax(delays_R))*second

        if not use_gains:
            # note: gains has shape (1, 2*cfN*num_indices) after the reshape
            # above, so ones(len(gains)) would only give a length-1 array
            gains = ones_like(gains)

        if not use_delays:
            delays_L = delays_R = zeros(len(d))
            delay_max = 2/samplerate
        
        # dummy sound, when we run apply() we replace it
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)
        
        gfb = Gammatone(Repeat(soundinput, cfN), hstack((cf, cf)))
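        # The two ear channels were repeated cfN times (left block then right
        # block) before the gammatone filtering above; Repeat(gfb, num_indices)
        # below repeats each ear/CF channel once per candidate location, giving
        # 2*cfN*num_indices channels that line up with the gains array.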
        
        gains_fb = FunctionFilterbank(Repeat(gfb, num_indices),
                                      lambda x:x*gains)
        
        # half-wave rectify (clip at zero) then apply the model's static compression
        compress = filtergroup_model['compress']
        cochlea = FunctionFilterbank(gains_fb, lambda x:compress(clip(x, 0, Inf)))
        
        # Create the filterbank group
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])
        
        # create the synchrony group
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(num_indices*cfN, cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)
        
        # set up the synaptic connectivity
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var', delay=True, max_delay=delay_max)
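        # G channel i (i < num_indices*cfN) is the left-ear input and channel
        # i+num_indices*cfN the right-ear input for coincidence detector i;
        # each connection carries the location-specific best delay for its ear.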
        for i in xrange(num_indices*cfN):
            C[i, i] = cd_weight
            C[i+num_indices*cfN, i] = cd_weight
            C.delay[i, i] = delays_L[i]
            C.delay[i+cfN*num_indices, i] = delays_R[i]

        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)
        
    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply the approximate filtering model to the given sound. The sound
        should be stereo, unless you specify the HRTF index (index), give the
        coordinates of the HRTF as keyword arguments, or pass an HRTF object
        directly (index=hrtf), in which case the sound should be mono and the
        given HRTF will be applied to it. Returns the spike count of the
        neurons in the synchrony group, with shape (cfN, num_indices).
        '''
        hrtf = None
        if isinstance(index, HRTF):
            # an HRTF object passed directly; this must be checked before the
            # generic index case, otherwise this branch is unreachable
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif len(indexkwds):
            hrtf = self.hrtfset(**indexkwds)
        if hrtf is not None:
            sound = hrtf(sound)
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        count = reshape(self.counter.count, (self.cfN, self.num_indices))
        return count

if __name__=='__main__':
    
    from plot_count import ircam_plot_count

    hrtfdb = get_ircam()
    subject = 1002
    hrtfset = hrtfdb.load_subject(subject)
    index = randint(hrtfset.num_indices) # pick a random source location to test
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    sound = whitenoise(500*ms)
    
    afmodel = ApproximateFilteringModel(hrtfset, cfmin, cfmax, cfN)
    
    count = afmodel(sound, index)
    
    ircam_plot_count(hrtfset, count, index=index)
    show()