from shared import *
from hrtf_analysis import *
from models import *
import gc
class IdealFilteringModel(object):
    '''
    Spiking model that filters a binaural sound through every HRTF pair of
    an HRTF set and counts coincidences per (frequency, HRTF index) channel.

    Initialise this object with an hrtfset, a cochlear range (cfmin, cfmax,
    cfN), and optionally:

    ``cd_model``
        a model for the coincidence detector neurons,
    ``filtergroup_model``
        a model for the filter neurons,
    ``use_normalisation_gains``
        whether or not to normalise the cochlear-filtered HRTFs, which
        improves performance by making each frequency band have the same
        power (and therefore comparable firing rates in the neurons).

    Both model arguments are dicts; the keys read here are ``'eqs'``,
    ``'parameters'``, ``'threshold'``, ``'reset'``, ``'refractory'`` and
    ``'init'``, plus ``'compress'`` for the filter group model and
    ``'weight'`` for the coincidence detector model.

    The __call__ method returns a count (see docstring of that method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 use_normalisation_gains=True,
                 ):
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        self.num_indices = num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
        # dummy sound, when we run __call__() we replace it
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)
        # The two input channels are fed to the HRTF filterbank in swapped
        # order (channel 1 first, then channel 0), each repeated once per
        # HRTF index.
        # NOTE(review): presumably this applies each ear's HRTF to the
        # opposite ear's signal - confirm against hrtfset.filterbank.
        hrtfset_fb = hrtfset.filterbank(
            RestructureFilterbank(soundinput,
                                  indexmapping=repeat([1, 0], hrtfset.num_indices)))
        # We normalise the different HRTFs because we don't want a stronger
        # response from channels with less attenuation in the HRTF, but rather
        # a stronger response when the filters are more closely equal
        if use_normalisation_gains:
            attenuations = hrtfset_attenuations(cfmin, cfmax, cfN, hrtfset)
            # attenuations has shape (2, hrtfset.num_indices, cfN); both
            # sides get the same gain, the inverse of the larger of the two
            # attenuation values
            gains_max = reshape(1/maximum(attenuations[0], attenuations[1]),
                                (1, hrtfset.num_indices, cfN))
            gains = vstack((gains_max, gains_max))
            # flatten to one gain per filterbank channel
            gains.shape = gains.size
            func = lambda x: x*gains
        else:
            func = lambda x: x
        gains_fb = FunctionFilterbank(Repeat(hrtfset_fb, cfN), func)
        gfb = Gammatone(gains_fb,
                        tile(cf, hrtfset_fb.nchannels))
        compress = filtergroup_model['compress']
        # clip negative values to zero, then apply the model's compression
        cochlea = FunctionFilterbank(gfb, lambda x:compress(clip(x, 0, Inf)))
        # Create the filterbank group
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])
        # create the synchrony group: one coincidence detector neuron per
        # (HRTF index, frequency channel) pair, on the same clock as G
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(num_indices*cfN, cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)
        # set up the synaptic connectivity: filter neurons i and
        # i + num_indices*cfN both project onto coincidence detector i
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var')
        for i in xrange(num_indices*cfN):
            C[i, i] = cd_weight
            C[i+num_indices*cfN, i] = cd_weight
        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)

    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply ideal filtering group to given sound, which should be a
        stereo sound unless you specify the HRTF index, or coordinates of
        the HRTF index as keyword arguments, in which case it should be a mono
        sound which will have the given HRTF applied to it. You can also
        specify index=hrtf, where hrtf is an HRTF object, which is then
        applied directly. If ``index`` is given, the keyword coordinates are
        ignored.

        Returns the spike count of the neurons in the synchrony group with
        shape (cfN, num_indices).
        '''
        hrtf = None
        # The HRTF check has to come first: an HRTF object is also
        # "not None", so testing for None first would send it to the
        # hrtfset lookup below and the index=hrtf form could never work.
        if isinstance(index, HRTF):
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif indexkwds:
            hrtf = self.hrtfset(**indexkwds)
        if hrtf is not None:
            sound = hrtf(sound)
        # swap the new sound in for the previous (or dummy) one and reset
        # the network and both neuron groups before running
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        # counts are stored flat, ordered (HRTF index, frequency channel);
        # transpose to get (cfN, num_indices)
        count = reshape(self.counter.count, (self.num_indices, self.cfN)).T
        return count
if __name__ == '__main__':
    # Demo: localise a white noise burst filtered through one randomly
    # chosen HRTF of an IRCAM subject, then plot the resulting spike counts.
    from plot_count import ircam_plot_count

    # HRTF set of the chosen subject from the IRCAM database
    subject = 1002
    hrtfset = get_ircam().load_subject(subject)

    # HRTF index the mono sound will be presented through
    index = randint(hrtfset.num_indices)

    # cochlear range: cfN centre frequencies between cfmin and cfmax
    cfmin = 150*Hz
    cfmax = 5*kHz
    cfN = 80

    sound = whitenoise(500*ms)

    ifmodel = IdealFilteringModel(hrtfset, cfmin, cfmax, cfN)
    count = ifmodel(sound, index)

    ircam_plot_count(hrtfset, count, index=index)
    show()