# Read simulated spike trains and plot them with matplotlib.
#
# Contents:
#   * synaptic-plasticity helpers: a pair-based Hebbian rule driven by
#     pre/post traces (syn_hebb, hebbian) and an interval-based "step" rule
#     (syn_step, syn_weights);
#   * SpikesReader: per-gid access to spike times stored either in a binary
#     .spk/.spk2 file or in a text file of "time gid" lines;
#   * SpikesWriter: writes per-gid spike times as big-endian binary parts;
#   * show_* helpers that plot rasters, instantaneous frequencies, synaptic
#     steps and step evolution, plus a command-line entry point.
#
# NOTE(review): this is Python 2 code (dict.has_key; range(offset / 8)
# relies on integer division) and will not run unmodified under Python 3.

# Hebbian (syn_hebb / hebbian) parameters.  Times are in ms: the plots label
# spike times in ms and frequency() converts intervals to Hz with 1000/dt.
aLTP=2e-2     # added to the potentiation trace P at each presynaptic event
aLTD=2e-2     # subtracted from the depression trace M at each postsynaptic event
tauLTP=20.    # decay time constant of P
tauLTD=20.    # decay time constant of M
delay_post=3  # added to every postsynaptic spike time in hebbian()
delay_pre=1   # added to every presynaptic spike time in hebbian()
wmax=1.       # upper clamp of the Hebbian weight (the lower clamp is 0)
# Step-rule parameters for excitatory (even gid) and inhibitory (odd gid)
# synapses.  show_evol() counts firing with an interval up to ltpinvl as a
# +1 step and up to ltdinvl as a -1 step; show_weights() normalizes the step
# counter by 2 * sighalf_*.
ltdinvl_excit = 250.
ltpinvl_excit = 33.33
sighalf_excit = 25.0
ltdinvl_inhib = 250.
ltpinvl_inhib = 33.33
sighalf_inhib = 25.0
# Project module; this file reads params.sniff_invl, params.use_fi_stdp,
# params.init_exc_weight and params.init_inh_weight from it.
import params


def syn_hebb(ispre, w, tpre, tpost, P, M, t):
    """Apply one Hebbian update at time t and return (w, P, M).

    If ispre is true (presynaptic event at t):
        P <- P * exp((tpre - t) / tauLTP) + aLTP
        dw = wmax * M * exp((tpost - t) / tauLTD)
    otherwise (postsynaptic event at t):
        M <- M * exp((tpost - t) / tauLTD) - aLTD
        dw = wmax * P * exp(-(t - tpre) / tauLTP)
    The new weight w + dw is clamped to [0, wmax].

    tpre / tpost are presumably the times of the previous pre / post
    events (the caller is in the truncated part of hebbian) -- confirm.
    """
    from math import exp
    if ispre:
        P=P*exp((tpre-t)/tauLTP)+aLTP
        interval=tpost-t
        dw=wmax*M*exp(interval/tauLTD)
    else:
        M=M*exp((tpost-t)/tauLTD)-aLTD
        interval=t-tpre
        dw=wmax*P*exp(-interval/tauLTP)
    w += dw
    if w > wmax: w = wmax
    elif w < 0: w = 0
    return w, P, M


def hebbian(winit, trpre, trpost):
    """Compute a Hebbian weight trajectory from pre/post spike-time lists.

    SpikesReader.step() unpacks the result as (t, w).  Only the set-up part
    of this function is visible; see the NOTE(review) below.
    """
    # NOTE: shifts the caller's lists in place (pre by delay_pre, post by
    # delay_post), so trpre/trpost are modified.
    for i in range(len(trpre)):
        trpre[i] += delay_pre
    for i in range(len(trpost)):
        trpost[i] += delay_post
    t=[0]
    w=[winit]
    P=M=0
    i=j=1
    # NOTE(review): SOURCE is damaged here.  Everything between "while i" and
    # "2 * sighalf:" is missing; the span looks as if a "<...>" run was
    # stripped as markup.  The lost text is the loop condition and body of
    # hebbian(), its return statement, and the "def syn_step(...)" header
    # with the start of its body.  The next line is therefore not valid
    # Python and must be restored from version control.  The statements
    # after it are the visible tail of syn_step(): they appear to clamp the
    # step counter s to [0, 2 * sighalf] and return it.
    while i 2 * sighalf:
        s = 2 * sighalf
    elif s < 0:
        s = 0
    return s


def syn_weights(t, winit=0, excit=True):
    """Return the step-rule trajectory for spike times t.

    Starts from winit and applies syn_step to each consecutive interval
    t[i] - t[i-1], so the result has len(t) values (one per spike time).
    """
    w = [ winit ]
    for i in range(1, len(t)):
        w.append(syn_step(w[-1], t[i] - t[i-1], excit))
    return w


class SpikesReader:
    """Per-gid access to spike times in a binary or text spike file.

    Binary files (name ends in .spk or .spk2), all big-endian:
      * int64 ('>q'): size in bytes of the header table;
      * .spk2 only: float32 ('>f') stop time, stored in self.tstop;
      * header table: (gid, n) pairs as two uint32 ('>LL');
      * data: for each header entry in order, n float32 spike times.
    Any other file is read as text, one "time gid" pair per line.

    Retrieved trains are kept in a small FIFO cache (cache_max_len gids).
    """

    def __init__(self, filename, *args):
        """Open filename; args[0], if given, is an initial-weights file.

        The initial-weights file is text with whitespace-separated integer
        "gid s" pairs at the start of each line.
        """
        self.sort = False  # if True, retrieve() returns sorted spike times
        from struct import unpack
        self.cache_max_len = 100
        self.tstop = None
        self.bincoded = filename.endswith('.spk') or filename.endswith('.spk2')
        self.__initweights = {} # initial weights
        if len(args) > 0:
            f = open(args[0], 'r')
            line = f.readline()
            while line:
                gid, s = line.split()[:2]
                gid = int(gid)
                s = int(s)
                self.__initweights.update({ gid:s })
                line = f.readline()
            f.close()
        if self.bincoded: # init for binary format
            self.header = {}
            self.fi = open(filename, 'rb')
            offset = unpack('>q', self.fi.read(8))[0]
            # .spk2 files also store the stop time right after the size field
            if filename.endswith('.spk2'):
                self.tstop = unpack('>f', self.fi.read(4))[0]
                # NOTE(review): use offset // 8 for Python 3 (range() needs int)
                hlen = offset / 8
                offset += 4
            else:
                hlen = offset / 8
            # offset now becomes the file position of the first data block
            offset += 8
            for j in range(hlen):
                gid, n = unpack('>LL', self.fi.read(8)) # read one header entry
                if not self.header.has_key(gid):
                    self.header.update({ gid:[(offset, n)] })
                else:
                    self.header[gid].append((offset, n))
                offset += n * 4
        else: # init for textual format
            # NOTE(review): self.header is only set for binary files, but
            # step() reads self.header, so it raises AttributeError for text
            # files.
            self.fi = open(filename, 'r')
        self.__cache = {}
        self.__old = []  # cached gids, oldest first (FIFO eviction order)

    def retrieve(self, gid):
        """Return a copy of the spike times of gid.

        Raises KeyError if gid has no spikes (or, for binary files, no
        header entry).
        """
        # if gid in cache don't retrieve
        if gid not in self.__cache:
            # evict the oldest cached gid when the cache is full
            if len(self.__cache) >= self.cache_max_len:
                del self.__cache[self.__old[0]]
                del self.__old[0]
            # read
            t = [ ]
            if self.bincoded: # binary format reading code
                # NOTE: only the LAST header entry of gid is read; earlier
                # entries for the same gid are ignored.
                offset, n = self.header[gid][-1]
                # for offset, n in self.header[gid]:
                self.fi.seek(offset)
                from struct import unpack
                # for i in range(n):
                #   t.append(unpack('>f', self.fi.read(4))[0])
                t = list(unpack('>' + 'f'*n, self.fi.read(4*n)))
            else: # if not bincoded
                # it's the old textual format: full scan of "time gid" lines
                # NOTE(review): seek(1) skips the first byte of the file, so
                # the first line is read truncated; this looks like it should
                # be seek(0) -- confirm.
                self.fi.seek(1)
                line = self.fi.readline()
                while line:
                    tks = line.split()
                    if int(tks[1]) == gid:
                        t.append(float(tks[0]))
                    line = self.fi.readline()
            if len(t) == 0:
                raise KeyError # callers treat KeyError as "no spikes for gid"
            if self.sort:
                t = sorted(t)
            self.__old.append(gid)
            self.__cache.update({ gid:t })
        from copy import copy
        return copy(self.__cache[gid])

    def freqvssniff(self, gid, tstart=50.0):
        """Bin the spikes of gid into sniff windows and return (t, rate).

        t holds bin positions starting at tstart + sniff_invl/2; rate holds
        spikes per bin scaled by 1000/sniff_invl.
        """
        t = [ tstart+params.sniff_invl*0.5 ]
        nspk = [ 0 ]
        for x in self.retrieve(gid):
            i = int((x-tstart)/params.sniff_invl)
            if i >= len(t):
                # NOTE(review): this advances the bin position by half a
                # sniff interval, and adds only one bin even if the spike
                # skipped several windows -- both look unintended; confirm.
                t.append(t[-1]+params.sniff_invl*0.5)
                nspk.append(0)
            nspk[-1] += 1*1000.0/params.sniff_invl
        return t, nspk

    def frequency(self, gid):
        """Return (t, fr): spike times prefixed by 0 and instantaneous Hz.

        fr[i] = 1000 / (t[i] - t[i-1]) and fr[0] = 0.
        NOTE(review): raises ZeroDivisionError on two equal consecutive
        times, including a spike at exactly t = 0.
        """
        t = [ 0. ] + self.retrieve(gid)
        fr = [ 0. ]
        for i in range(1, len(t)):
            fr.append(1000. / (t[i] - t[i - 1]))
        return t, fr

    def stepvssniff(self, gid, tstart=50.0):
        """Return step() sampled once per sniff interval.

        NOTE(review): tstart is ignored; step() uses its own tlast default.
        """
        return self.step(gid, dt=params.sniff_invl)

    def step(self, gid, dt=None, tlast=50.0):
        """Return the synaptic trajectory (t, values) for synapse gid.

        Returns None for gids below mgrs.gid_mgrs_begin.
        * Odd gid with params.use_fi_stdp: Hebbian trajectory from
          hebbian(), using the spikes of gid and gid + 1 (presumably pre
          and post -- confirm).
        * Otherwise: step-rule trajectory via syn_step, excitatory when gid
          is even.  The initial value comes from the initial-weights file,
          else params.init_inh_weight / params.init_exc_weight, else 0.
          With dt None, one value is returned per spike (t starts with
          0.).  With any other dt, at most one sample per spike is kept on
          a grid starting at tlast and stepping by params.sniff_invl; the
          value of dt itself is not used.
        """
        from mgrs import gid_mgrs_begin
        if gid < gid_mgrs_begin:
            return None
        if gid%2!=0 and params.use_fi_stdp:
            if self.__initweights.has_key(gid):
                wi=self.__initweights[gid]
            else:
                wi=0
            tpre=[0.]
            if self.header.has_key(gid):
                tpre += self.retrieve(gid)
            tpost=[0.]
            # NOTE(review): this checks gid but reads gid+1; it probably
            # should test self.header.has_key(gid+1).
            if self.header.has_key(gid):
                tpost += self.retrieve(gid+1)
            t,w = hebbian(wi,tpre,tpost)
            return t,w
        else:
            t = [ 0. ]
            try:
                s = [ self.__initweights[gid] ]
            except KeyError:
                try:
                    if gid%2:
                        init_weight = params.init_inh_weight
                    else:
                        init_weight = params.init_exc_weight
                except:
                    # NOTE(review): bare except; AttributeError is the
                    # likely intended case (params lacks the attribute).
                    init_weight = 0
                s = [ init_weight ]
            try:
                t += self.retrieve(gid)
            except KeyError:
                return t, s
            if dt == None:
                for i in range(1, len(t)):
                    s.append(syn_step(s[-1], t[i] - t[i-1], excit=(gid%2 == 0)))
                return t, s
            _t = [ ]
            _s = [ ]
            for i in range(1, len(t)):
                s.append(syn_step(s[-1], t[i] - t[i-1], excit=(gid%2 == 0)))
                if t[i]+params.sniff_invl > tlast:
                    _t.append(tlast)
                    _s.append(s[-1])
                    tlast += params.sniff_invl
            return _t, _s

    def close(self):
        """Close the underlying spike file."""
        self.fi.close()


class SpikesWriter:
    """Write per-gid spike times as big-endian binary parts.

    Data go to <filename>.data; close() writes <filename>.time,
    <filename>.header and <filename>.size.  Concatenating .size, .time,
    .header and .data in that order would give the .spk2 layout that
    SpikesReader parses -- presumably done externally; confirm.
    """

    def __init__(self, filename, tstop):
        self.filename = filename
        self.__fo = open(filename + '.data', 'wb')
        self.header = {}  # gid -> number of spike times written
        self.tstop = tstop

    # NOTE(review): write() has no self parameter, so inst.write(gid, t)
    # raises TypeError.  pack() also needs *t rather than the list t.
    def write(gid, t):
        from struct import pack
        self.header[gid] = len(t)
        self.__fo.write(pack('>'+('f'*len(t)), t))

    def close(self, filename):
        """Close the data file and write the .time, .header and .size parts.

        NOTE(review): the filename parameter is unused (self.filename is
        used instead).
        """
        self.__fo.close()
        from struct import pack
        fo = open(self.filename + '.time', 'wb')
        fo.write(pack('>f', self.tstop))
        fo.close()
        # write header: one (gid, count) '>LL' entry per gid
        # NOTE(review): pack('>LL', x) needs *x (x is a tuple).  Also,
        # dict.items() order is arbitrary in Python 2, while the reader
        # assumes header order matches the order the data blocks were
        # written.
        fo = open(self.filename + '.header', 'wb')
        for x in self.header.items():
            fo.write(pack('>LL', x))
        fo.close()
        from os import path
        fo = open(self.filename + '.size', 'wb')
        fo.write(pack('>q', path.getsize(self.filename + '.header')))
        fo.close()


# Fallback stop time (ms) used by show() for the x range when the reader
# does not know tstop; overridden by "-tstop <value>" on the command line.
tstop = 20050
try:
    from sys import argv
    tstop = float(argv[argv.index('-tstop') + 1])
except:
    # NOTE(review): the bare except also hides a malformed -tstop value;
    # consider catching (ValueError, IndexError) only.
    pass

# @@@@@@@@@@@@@@


def show(sr, gids, xlabel, ylabel, call, title, ylim, legend=True):
    """Draw one new figure with call(gid, index, color, label) per gid.

    call must draw gid on the current axes and return True if it drew
    something.  The label is bindict.query(gid)[-1].  ylim is applied only
    if it has two entries; the x range is [0, sr.tstop] or [0, tstop].

    Returns None if gids is empty, False (figure closed) if nothing was
    drawn, and True otherwise.
    """
    if len(gids) == 0:
        return
    from bindict import query as descr
    import matplotlib.pyplot as plt
    plt.figure()
    color = [ 'b', 'g', 'r', 'c', 'm', 'y', 'k' ]
    # NOTE: despite its name, never_drawed becomes True as soon as any
    # call() reports that it drew something.
    never_drawed = False
    for i, g in enumerate(gids):
        never_drawed = never_drawed | call(g, i, color[i % len(color)], descr(g)[-1])
    if not never_drawed:
        plt.close()
        return False
    if legend:
        # NOTE(review): Legend.draggable() is deprecated in newer matplotlib
        # (set_draggable(True) replaces it) -- confirm the target version.
        plt.legend().draggable()
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.title(title)
    if len(ylim) == 2:
        plt.ylim(ylim)
    if sr.tstop:
        plt.xlim([ 0, sr.tstop ])
    elif tstop:
        plt.xlim([ 0, tstop ])
    plt.draw()
    return True


def show_raster(sr, gids):
    """Plot a spike raster, one row per gid (row = enumeration index)."""
    import matplotlib.pyplot as plt
    def raster(gid, i, col, descr):
        try:
            t = sr.retrieve(gid)
            plt.scatter(t, [ i ] * len(t), s=10, marker='|', label=descr, c=col)
        except KeyError:
            return False
        return True
    return show(sr, gids, 'spike time (ms)', '', raster, 'Spike raster', [ -1, len(gids) + 1 ])


def show_freqs(sr, gids):
    """Plot the instantaneous firing frequency (Hz) of each gid."""
    import matplotlib.pyplot as plt
    def freq(gid, i, col, descr):
        try:
            t, fr = sr.frequency(gid)
            plt.plot(t, fr, '-' + col + 'o', label=descr)
        except KeyError:
            return False
        return True
    return show(sr, gids, 'spike time (ms)', 'Freq. (Hz)', freq, 'Frequency', [])


def show_weights(sr, gids):
    """Plot normalized synaptic trajectories from sr.step() for gids.

    gids must be a set; gids below mgrs.gid_mgrs_begin are dropped.
    Values are divided by 2 * sighalf_excit (even gid), by wmax (odd gid
    with params.use_fi_stdp) or by 2 * sighalf_inhib (other odd gids).
    """
    from mgrs import gid_mgrs_begin
    # drop gids below gid_mgrs_begin: they are not synapses (step() returns None)
    gids = gids.difference(set(range(gid_mgrs_begin)))
    import matplotlib.pyplot as plt
    def step(gid, i, col, descr):
        try:
            t, d = sr.step(gid)
            if gid%2==0:
                maxsig=2*sighalf_excit
            elif params.use_fi_stdp:
                maxsig=wmax
            else:
                maxsig=2*sighalf_inhib
            # NOTE: this loop variable shadows the parameter i (unused here)
            for i in range(len(d)):
                d[i] = d[i]/maxsig #* maxsig
            plt.plot(t, d, col + '-', label=descr)
        except KeyError:
            return False
        return True
    return show(sr, gids, 'spike time (ms)', 'Step', step, 'Syn. Steps', [-0.1, 1.1])#[ -1, 2 * max(sighalf_inhib, sighalf_excit) + 1])


def show_evol(sr, gids, tstart=50.0):
    """Plot a per-sniff cumulative step count for each gid.

    Within each sniff window (params.sniff_invl), every spike with
    instantaneous frequency >= 1000/ltpinvl adds +1, one with
    frequency >= 1000/ltdinvl adds -1, and others add nothing.  The count
    restarts at 0 in each new window.
    NOTE(review): tstart is unused.
    """
    import matplotlib.pyplot as plt
    def evol(gid, i, col, descr):
        try:
            if gid%2:
                ltpinvl=ltpinvl_inhib
                ltdinvl=ltdinvl_inhib
            else:
                ltpinvl=ltpinvl_excit
                ltdinvl=ltdinvl_excit
            t, fr = sr.frequency(gid)
            dw = [0]*len(fr)
            isniff = 0
            lastdw = 0
            for i in range(1, len(t)):
                _isniff = int(t[i]/params.sniff_invl)
                if _isniff > isniff:
                    # new sniff window: restart the running count
                    isniff = _isniff
                    lastdw = 0
                if fr[i] >= 1000/ltpinvl:
                    dw[i] = lastdw + 1
                elif fr[i] >= 1000/ltdinvl:
                    dw[i] = lastdw - 1
                else:
                    dw[i] = lastdw
                lastdw = dw[i]
            plt.plot(t, dw, '-' + col + 'o', label=descr)
        except KeyError:
            return False
        return True
    return show(sr, gids, 'spike time (ms)', 'DStep', evol, 'Evolution', [])


# Command-line entry point:
#   <script> -i <spike file> -gid <gid> [<gid> ...] [-tstop <ms>]
# The gid list ends at the first argument that is not an integer.
if __name__ == '__main__':
    from sys import argv
    i = argv.index('-i')
    gids = set()
    for sg in argv[argv.index('-gid') + 1:]:
        try:
            gids.add(int(sg))
        except ValueError:
            break
    sr = SpikesReader(argv[i + 1])
    # show all
    import matplotlib.pyplot as plt
    show_freqs(sr, gids)
    show_weights(sr, gids)
    show_raster(sr, gids)
    plt.show()