"""Test a trained network saved as a ``Network_<itr>.obj`` file.

The network file is generated during training by runMe.py. This script loads
its weights and delays, stops synaptic modification, plots histograms of the
STDP synaptic weights of each projection, and then presents colour stimuli to
the network and plots/saves the resulting spike counts.

@author: Akihiro Eguchi
Aug 13, 2013
"""
from Controller import *  # provides Controller (and presumably NEURON's ``h``)
import matplotlib.pyplot as plt
import pylab
from time import time

nTrainings = 2001
ndim = 30
dim_stim = 10
controller = Controller(dim_stim, ndim)
resultFolderName = "results"
controller.resultFolderName = resultFolderName

transOn = 1        # also present slightly perturbed colours ("transformation" test)
singleColor = 1    # 1: uniform single-colour stimuli; otherwise: multi-colour test
alteringInput = 0  # not fully implemented yet
itr = 2000         # selects the saved network file "Network_<itr>.obj"


class InputStim:
    """One cell of the input stimulus grid, holding an RGB colour."""
    r = 0
    g = 0
    b = 0

    def setRGB(self, r, g, b):
        """Set this cell's red, green and blue components."""
        self.r = r
        self.g = g
        self.b = b


# dim_stim x dim_stim grid of stimulus cells, indexed as inputStim[y][x].
inputStim = [[InputStim() for _ in range(dim_stim)] for _ in range(dim_stim)]

controller.loadWeightsAndDelays("Network_" + str(itr) + ".obj", resultFolderName)
controller.setLearningStates(0)  # stop synaptic modifications
controller.variables.tstop = 300

nThreads = 7
controller.pc = h.ParallelContext()
controller.pc.nthread(nThreads)

if alteringInput:
    controller.AlteringInputInit()


def _collect_weights(netcons):
    """Return ``weight[0]`` of every connection in ``netcons``.

    Only ``len()`` and integer indexing are used on the container, which is
    all the connection lists are known to support here.
    """
    return [netcons[i].weight[0] for i in range(len(netcons))]


# Each projection is read from its own connection list, so projections of
# different sizes are neither truncated nor indexed out of range.
weightTemp_LtoL4 = _collect_weights(controller.NetCons_STDP_LtoL4)
weightTemp_C1toL4 = _collect_weights(controller.NetCons_STDP_C1toL4)
weightTemp_C2toL23 = _collect_weights(controller.NetCons_STDP_C2toL23)
weightTemp_L4toL23 = _collect_weights(controller.NetCons_STDP_L4toL23)
weightTemp_L23toL5 = _collect_weights(controller.NetCons_STDP_L23toL5)

# Histograms of the frozen STDP weights, one subplot row per projection.
plt.subplot(5, 1, 1)
plt.hist(weightTemp_L23toL5, bins=100, range=[0, 0.005*3])
plt.xlim(0, 0.005*3)
plt.xlabel("synaptic weights between V1_L23 and V1_L5")
plt.ylabel("number of synapses")
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)

plt.subplot(5, 1, 2)
plt.hist(weightTemp_C2toL23, bins=100, range=[0, 0.005*1.5])
plt.xlim(0, 0.005*1.5)
plt.xlabel("synaptic weights between C2 and V1_L23")
plt.ylabel("number of synapses")
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)

# Remaining weight histograms: (subplot row, weights, x-range upper bound, label).
# NOTE(review): hiding the x-axis also hides the x-label set just before it,
# so these labels are never drawn -- confirm whether that is intended.
for _row, _weights, _xmax, _label in (
        (3, weightTemp_L4toL23, 0.005*1.5, "synaptic weights between V1_L4 and V1_L23"),
        (4, weightTemp_C1toL4, 0.005, "synaptic weights between C1 and V1_L4"),
        (5, weightTemp_LtoL4, 0.005, "synaptic weights between L and V1_L4"),
):
    plt.subplot(5, 1, _row)
    plt.hist(_weights, bins=100, range=[0, _xmax])
    plt.xlim(0, _xmax)
    plt.xlabel(_label)
    plt.ylabel("number of synapses")
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
plt.show()


def _fill_stimulus(r, g, b):
    """Set every cell of the global ``inputStim`` grid to the colour (r, g, b)."""
    for y in range(dim_stim):
        for x in range(dim_stim):
            inputStim[y][x].setRGB(r, g, b)


def _present_variant(r, g, b, r_ref, g_ref, b_ref):
    """Present a uniform (r, g, b) stimulus and record the response.

    The response is written via ``controller.outputFR_trans`` under the
    reference colour (r_ref, g_ref, b_ref) this variant was derived from.
    """
    _fill_stimulus(r, g, b)
    # NOTE(review): the reference stimulus is presented with setInput(..., 0.8),
    # but variants use setInput's default second argument -- confirm intended.
    controller.setInput(inputStim)
    controller.recordVols()
    controller.run()
    controller.updateSpikeCount()
    controller.outputFR_trans(r_ref, g_ref, b_ref, itr)


if singleColor == 1:
    # Single-colour test over the eight binary (r, g, b) combinations, with
    # black replaced by purple and white by orange:
    #   purple 0.5 0 1 | blue 0 0 1 | light-green 0 1 0 | light-blue 0 1 1
    #   red 1 0 0      | pink 1 0 1 | yellow 1 1 0      | orange 1 0.5 0
    fig1 = plt.gcf()
    plt.clf()
    for r in range(2):
        for g in range(2):
            for b in range(2):
                r2, g2, b2 = r, g, b
                if r == 0 and g == 0 and b == 0:
                    r2, g2, b2 = 0.5, 0, 1  # purple instead of black
                elif r == 1 and g == 1 and b == 1:
                    r2, g2, b2 = 1, 0.5, 0  # orange instead of white

                _fill_stimulus(r2, g2, b2)
                if alteringInput:
                    controller.setAlteringInput(inputStim, 0.25)
                else:
                    controller.setInput(inputStim, 0.8)
                controller.recordVols()
                # controller.recordChannelVols()
                controller.run()
                controller.updateSpikeCount()
                # controller.outputFR(itr)
                # controller.saveSpikeDetails(r,g,b,itr)
                # controller.saveChannelSpikeDetails(r,g,b,itr)

                # 8x4 grid holding three 2x4 panels, one cell per colour:
                # rows 1-2 show L5, rows 4-5 L23 and rows 7-8 L4 spike counts.
                cell = r*4 + g*2 + b
                plt.subplot(8, 4, cell + 1)
                plt.imshow(controller.spikeCount_L5, cmap=plt.cm.gray)
                plt.colorbar()
                plt.subplot(8, 4, cell + 13)
                plt.imshow(controller.spikeCount_L23, cmap=plt.cm.gray)
                plt.colorbar()
                plt.subplot(8, 4, cell + 25)
                plt.imshow(controller.spikeCount_L4, cmap=plt.cm.gray)
                plt.colorbar()

                if transOn:
                    controller.outputFR_trans(r2, g2, b2, itr)
                    # Transformation test: present similar colours in which one,
                    # two or all three components are moved by modVal
                    # (up from 0, down from any other value).
                    modVal = 0.01
                    rMod = r2 + modVal if r2 == 0 else r2 - modVal
                    gMod = g2 + modVal if g2 == 0 else g2 - modVal
                    bMod = b2 + modVal if b2 == 0 else b2 - modVal
                    for rv, gv, bv in ((rMod, g2, b2), (r2, gMod, b2), (r2, g2, bMod),
                                       (rMod, gMod, b2), (r2, gMod, bMod), (rMod, g2, bMod),
                                       (rMod, gMod, bMod)):
                        _present_variant(rv, gv, bv, r2, g2, b2)
    fig1.savefig(resultFolderName + "/" + str(itr), dpi=100)
    plt.show()
else:
    # controller.variables.tstop = 300
    # Multi-colour test: a red (resp. green) field containing a patch of the
    # other colour; blue is always 0.
    fig1 = plt.gcf()
    plt.clf()
    b = 0
    for r, g in ((0, 1), (1, 0)):
        for y in range(dim_stim):
            for x in range(dim_stim):
                # NOTE(review): the original condition was corrupted
                # ("ydim_stim/3" -- an undefined name that raised NameError).
                # Reconstructed as: rows strictly inside the middle third,
                # columns in the right-most third. Confirm the intended patch.
                if dim_stim // 3 < y < dim_stim * 2 // 3 and x > dim_stim * 2 // 3:
                    inputStim[y][x].setRGB(1 - r, 1 - g, b)
                else:
                    inputStim[y][x].setRGB(r, g, b)
        controller.setInput(inputStim, 0.3)
        controller.recordVols()
        controller.run()
        controller.updateSpikeCount()
        controller.outputFR(itr)
        plt.subplot(2, 1, r + 1)
        plt.imshow(controller.spikeCount_L4, cmap=plt.cm.gray)
        plt.colorbar()
        # controller.drawGraph()
        controller.saveSpikeDetails(r, g, b, 111110)
    fig1.savefig(resultFolderName + "/multiColTest300_normal" + str(itr), dpi=100)

controller.setLearningStates(1)  # start synaptic modifications
# raw_input("Press Enter to exit...")