Spike-Timing-Based Computation in Sound Localization (Goodman and Brette 2010)

 Download zip file 
Help downloading and running models
Accession:126465
" ... In neuron models consisting of spectro-temporal filtering and spiking nonlinearity, we found that the binaural structure induced by spatialized sounds is mapped to synchrony patterns that depend on source location rather than on source signal. Location-specific synchrony patterns would then result in the activation of location-specific assemblies of postsynaptic neurons. We designed a spiking neuron model which exploited this principle to locate a variety of sound sources in a virtual acoustic environment using measured human head-related transfer functions. ..."
Reference:
1. Goodman DF, Brette R (2010) Spike-timing-based computation in sound localization. PLoS Comput Biol 6:e1000993 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: Brian; Python;
Model Concept(s): Coincidence Detection; Synchronization;
Implementer(s): Goodman, Dan F. M.;
from shared import *
from hrtf_analysis import *
from models import *
import gc

class IdealFilteringModel(object):
    '''
    Spiking "ideal filtering" sound localisation model built on an HRTF set.

    Initialise this object with an hrtfset, a cochlear range (cfmin, cfmax, cfN),
    and optionally:
    a model for the coincidence detector neurons (cd_model),
    a model for the filter neurons (filtergroup_model),
    whether or not to normalise the cochlear-filtered HRTFs, which improves
    performance by making each frequency band have the same power (and therefore
    comparable firing rates in the neurons) (use_normalisation_gains).

    Network built in __init__:

    * The stereo input is restructured so that channel 1 feeds the first
      num_indices HRTF filterbank channels and channel 0 feeds the last
      num_indices channels.
    * Each HRTF channel is (optionally) gain-normalised and split into cfN
      gammatone channels with centre frequencies from erbspace(cfmin, cfmax,
      cfN). The result is half-wave rectified (clip at 0) and passed through
      filtergroup_model['compress'].
    * Each cochlear channel drives one filter neuron (FilterbankGroup).
    * Coincidence detector k (0 <= k < num_indices*cfN) receives one synapse
      of weight cd_model['weight'] from filter neuron k and one from filter
      neuron k + num_indices*cfN. These are the same (HRTF index, centre
      frequency) pair taken from the two halves.

    The __call__ method returns a count (see docstring of that method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 use_normalisation_gains=True,
                 ):
        '''
        Build the cochlear filterbank, filter neurons, coincidence detectors
        and the Brian network that holds them.

        hrtfset: HRTF set to localise against. It is used for num_indices,
            filterbank(), indexing (hrtfset[index]) and keyword lookup
            (hrtfset(**kwds)).
        cfmin, cfmax, cfN: range and number of cochlear channels (passed to
            erbspace).
        filtergroup_model: dict whose keys 'eqs', 'parameters', 'threshold',
            'reset', 'refractory', 'compress' and 'init' are read here or in
            __call__.
        cd_model: dict whose keys 'eqs', 'parameters', 'threshold', 'reset',
            'refractory', 'weight' and 'init' are read here or in __call__.
        use_normalisation_gains: if True, scale every (HRTF index, cf)
            channel by 1/max(attenuation ear 0, attenuation ear 1).
        '''
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model

        self.num_indices = num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)

        # Dummy stereo sound; __call__ replaces it with the real input by
        # assigning to soundinput.source.
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)

        # indexmapping is [1]*num_indices + [0]*num_indices: input channel 1
        # feeds the first half of the HRTF filterbank, channel 0 the second.
        # NOTE(review): presumably hrtfset.filterbank orders its channels as
        # all left HRTFs then all right HRTFs, so each ear's signal is
        # filtered by the opposite ear's HRTF (confirm against brian.hears).
        hrtfset_fb = hrtfset.filterbank(
                RestructureFilterbank(soundinput,
                        indexmapping=repeat([1, 0], num_indices)))

        # We normalise the different HRTFs because we don't want a stronger
        # response from channels with less attenuation in the HRTF, but rather
        # a stronger response when the filters are more closely equal
        if use_normalisation_gains:
            attenuations = hrtfset_attenuations(cfmin, cfmax, cfN, hrtfset)
            # shape: (2, num_indices, cfN). The same gain is used for both
            # halves, so the two inputs of each coincidence detector are
            # scaled identically.
            gains_max = reshape(1/maximum(attenuations[0], attenuations[1]),
                                (1, num_indices, cfN))
            gains = vstack((gains_max, gains_max))
            # Flatten in C order to (half, index, cf). This assumes Repeat()
            # below places the cfN copies of each channel consecutively, which
            # is the same layout tile(cf, ...) assumes.
            gains.shape = gains.size
            func = lambda x: x*gains
        else:
            func = lambda x: x

        gains_fb = FunctionFilterbank(Repeat(hrtfset_fb, cfN), func)

        # One gammatone channel per (HRTF channel, cf), with cf cycling fastest.
        gfb = Gammatone(gains_fb,
                        tile(cf, hrtfset_fb.nchannels))

        # Half-wave rectify, then apply the model's compression nonlinearity.
        compress = filtergroup_model['compress']
        cochlea = FunctionFilterbank(gfb, lambda x:compress(clip(x, 0, Inf)))

        # Create the filterbank group: one filter neuron per cochlear channel,
        # driven through the state variable 'target_var'.
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])

        # Create the synchrony (coincidence detector) group: one neuron per
        # (HRTF index, cf) pair, sharing the filter group's clock.
        num_cd = num_indices*cfN
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(num_cd, cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)

        # Set up the synaptic connectivity: detector i pairs filter neuron i
        # (first half) with filter neuron i + num_cd (second half).
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var')
        for i in xrange(num_cd):
            C[i, i] = cd_weight
            C[i+num_cd, i] = cd_weight

        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)

    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply ideal filtering group to given sound, which should be a
        stereo sound unless you specify the HRTF index, or coordinates of
        the HRTF index as keyword arguments, in which case it should be a mono
        sound which will have the given HRTF applied to it. You can also
        specify index=hrtf. Returns the spike count of the neurons in the synchrony
        group with shape (cfN, num_indices).

        The network is reinitialised and both models' 'init' functions are
        called before every run, so successive calls are independent.
        '''
        # Check for an HRTF object first. An HRTF is also "not None", so
        # testing `index is not None` first would call self.hrtfset[hrtf],
        # and the documented index=hrtf form could never be reached.
        if isinstance(index, HRTF):
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif indexkwds:
            hrtf = self.hrtfset(**indexkwds)
        else:
            hrtf = None
        if hrtf is not None:
            sound = hrtf(sound)
        # Swap the real input into the dummy input filterbank, reset the
        # network state, then let each model initialise its own group.
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        # Detectors are laid out as (index, cf); transpose to (cfN, num_indices).
        count = reshape(self.counter.count, (self.num_indices, self.cfN)).T
        return count

if __name__ == '__main__':
    # Demo: localise a white-noise burst filtered through a randomly chosen
    # HRTF of one IRCAM subject, then plot the coincidence-detector counts.
    from plot_count import ircam_plot_count

    subject = 1002
    hrtfdb = get_ircam()
    hrtfset = hrtfdb.load_subject(subject)

    # Random true source location; drawn before the noise so the random
    # stream is consumed in the same order as before.
    index = randint(hrtfset.num_indices)

    # Cochlear range handed to the model: cfN channels from cfmin to cfmax.
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    sound = whitenoise(500*ms)

    ifmodel = IdealFilteringModel(hrtfset, cfmin, cfmax, cfN)
    count = ifmodel(sound, index)

    ircam_plot_count(hrtfset, count, index=index)
    show()