A spiking neural network model of model-free reinforcement learning (Nakano et al 2015)

 Download zip file 
Help downloading and running models
Accession:168143
"Spiking neural networks provide a theoretically grounded means to test computational hypotheses on neurally plausible algorithms of reinforcement learning through numerical simulation. ... In this work, we use a spiking neural network model to approximate the free energy of a restricted Boltzmann machine and apply it to the solution of PORL (partially observable reinforcement learning) problems with high-dimensional observations. ... The way spiking neural networks handle PORL problems may provide a glimpse into the underlying laws of neural information processing which can only be discovered through such a top-down approach. "
Reference:
1. Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A spiking neural network model of model-free reinforcement learning with high-dimensional sensory input and perceptual ambiguity. PLoS One 10:e0115620 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s): Abstract integrate-and-fire leaky neuron;
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: NEST;
Model Concept(s): Reinforcement Learning;
Implementer(s): Nakano, Takashi [nakano.takashi at gmail.com];
# -*- coding: utf-8 -*-
####################################
#Nakano T, Otsuka M, Yoshimoto J, Doya K (2015) A Spiking Neural Network Model of Model-Free Reinforcement Learning with High-Dimensional Sensory Input and Perceptual Ambiguity. PLoS ONE 10(3): e0115620. doi:10.1371/journal.pone.0115620
####################################

import nest
nest.pynestkernel.logstdout("")
import nest.raster_plot
import nest.voltage_trace
import pylab
import numpy
import scipy
import time, sys
import matplotlib.pyplot as plt
import random
import operator
import math
import csv

# Reset the NEST kernel before any node is created, so the GIDs handed out in SNNc.__init__ start from a clean state.
nest.ResetKernel()
# Set the kernel's "overwrite_files" flag (applies to files written by recording devices).
nest.SetKernelStatus({"overwrite_files": True})
####################################
class SNNc:
	"""NEST network with state (S), hidden (H), action (A) and memory (M) populations.

	__init__ creates the nodes, input devices and all static connections.
	InitW() fills the Python-side weight matrices, ConnectW_SM() and ConnectW()
	copy those matrices onto the existing NEST connections.

	Weight matrices (lists of lists of floats, one row per source neuron):
	  Whs: NS rows x NH columns   (state  <-> hidden)
	  Wha: NA rows x NH columns   (action <-> hidden, used in both directions)
	  Whm: NM rows x NH columns   (memory  -> hidden)
	  Wmm: read from ./MMweight2301.txt (memory -> memory)
	  Wms: read from Wcd50_noBias.txt   (state  -> memory, one row per state neuron)

	NOTE: the order of nest.Create / nest.*Connect calls below must not change.
	CalcFR and BinSPcnt assume GIDs 1..NS are state neurons, followed by hidden,
	action and memory neurons, and ConnectW relies on the order in which each
	source's outgoing connections were created.
	"""
	def __init__(self):
		#population sizes
		self.NS=20*15 # number of state neurons
		self.NH=90 # number of hidden neurons
		self.NA=90 # number of action neurons
		self.NM=50 # number of memory neurons
		#initial weight distributions (used by InitW)
		self.InitWhsmean=20 # mean weight Whs
		self.InitWhsstd=11.88  # std weight Whs
		self.InitWhamean=20 # mean weight Wha
		self.InitWhastd=11.88  # std weight Wha
		
		self.InitWhmmean=20 # mean weight Whm
		self.InitWhmstd=11.88  # std weight Whm
		self.InitWmmmean=0 # mean weight Wmm (not read by InitW, which loads Wmm from file)
		self.InitWmmstd=0.  # std weight Wmm (not read by InitW, which loads Wmm from file)
		
		self.Wseed=0#seed for weight
		self.Nseed=[123]#seed for noise
		self.driveparams_on  = {'amplitude':2000.}#clamp current for an active state neuron
		self.driveparams_on_a  = {'amplitude':2000.}#clamp current for the selected action group
		self.driveparams_off_a  = {'amplitude':-5000.}#negative current for the non-selected action groups
		self.driveparams_off  = {'amplitude':0.}#no current input
		self.driveparams_inh  = {'amplitude':-5000.}#negative current applied to all memory neurons
		self.noiseparams  = {'mean':0.0, 'std':300.}#noise current parameters
		self.sdparams  = { "withtime": True, "withgid" : True,'to_file':False, 'to_screen':False,'flush_after_simulate':True,'flush_records':True}
		#neuronparams = { 'tau_m':20., 'V_th':-50., 'E_L':-60., 't_ref':2., 'V_reset':-60., 'C_m':200.}
		
		#create neurons (creation order fixes the GIDs: S, then H, then A, then M)
		self.Sneurons = nest.Create('iaf_neuron',self.NS)
		self.Hneurons = nest.Create('iaf_neuron',self.NH)
		self.Aneurons = nest.Create('iaf_neuron',self.NA)
		self.Mneurons = nest.Create('iaf_neuron',self.NM)
		
		self.sd= nest.Create('spike_detector')
		#drive layout: [0:NS] one per state neuron, [NS:NS+4] one per action group, [NS+4] memory inhibition
		self.drive= nest.Create('dc_generator',self.NS+4+1)
		nest.SetKernelStatus({'rng_seeds':self.Nseed})#noise
		self.noise= nest.Create('noise_generator',6)
		self.voltmeter = nest.Create("voltmeter")
	
		#set parameters
		nest.SetStatus(self.sd,[self.sdparams] )
		nest.SetStatus(self.noise,[self.noiseparams] ) # if noise selection works, comment out this.
	
		#connect noise generators 0..3 to S, A, H, M (generators 4 and 5 stay unconnected)
		nest.DivergentConnect(self.noise[0:1], self.Sneurons)
		nest.DivergentConnect(self.noise[1:2], self.Aneurons)
		nest.DivergentConnect(self.noise[2:3], self.Hneurons)
		nest.DivergentConnect(self.noise[3:4], self.Mneurons)
		
		#one clamp generator per state neuron
		for i in range(self.NS):
			nest.Connect(self.drive[i:i+1],self.Sneurons[i:i+1] )
		
		#one clamp generator per action group; the action population is cut into four
		#consecutive groups at NA/4, NA/2 and NA*3/4 (floor division, same bounds as CalcFR)
		bounds=[0, self.NA//4, self.NA//2, self.NA*3//4, self.NA]
		for k in range(4):
			nest.DivergentConnect(self.drive[self.NS+k:self.NS+k+1], self.Aneurons[bounds[k]:bounds[k+1]])
		
		nest.DivergentConnect(self.drive[self.NS+4:self.NS+5], self.Mneurons)
		
		#all-to-all projections; the weights given here are placeholders that
		#ConnectW / ConnectW_SM overwrite later
		nest.ConvergentConnect(self.Sneurons, self.Hneurons, weight=100.0, delay=1.0)
		nest.ConvergentConnect(self.Hneurons, self.Aneurons, weight=100.0, delay=1.0)
		nest.ConvergentConnect(self.Aneurons, self.Hneurons, weight=100.0, delay=1.0)
		
		nest.ConvergentConnect(self.Mneurons, self.Hneurons, weight=100.0, delay=1.0)
		
		nest.ConvergentConnect(self.Mneurons, self.Mneurons, weight=0.0, delay=1.0)
		
		nest.ConvergentConnect(self.Sneurons, self.Mneurons, weight=0.0, delay=1.0)
		
		#every neuron reports to the single spike detector
		nest.ConvergentConnect(self.Sneurons, self.sd)
		nest.ConvergentConnect(self.Hneurons, self.sd)
		nest.ConvergentConnect(self.Aneurons, self.sd)
		nest.ConvergentConnect(self.Mneurons, self.sd)
		
		nest.Connect(self.voltmeter, self.Hneurons[20:21])
		
		#seed Python's RNG, which InitW uses to draw the initial weights
		random.seed(self.Wseed)
	
	def _random_matrix(self, nrows, mean, std):
		"""Return nrows x NH Gaussian samples, drawn row by row from Python's random module."""
		return [[random.normalvariate(mean, std) for _ in range(self.NH)] for _ in range(nrows)]
	
	def _read_matrix(self, filename):
		"""Read a comma-separated text file into a list of rows of floats (file is closed on return)."""
		with open(filename) as csvfile:
			return [[float(elem) for elem in row] for row in csv.reader(csvfile)]
	
	def InitW(self):
		"""Draw random Whs, Wha and Whm and load Wmm and Wms from their text files.

		The random draws happen in the order Whs, Wha, Whm (row-major), so the
		result is reproducible for a given random seed.
		Raises IOError/OSError if a weight file is missing and ValueError if a
		field is not a number.
		"""
		self.Whs=self._random_matrix(self.NS, self.InitWhsmean, self.InitWhsstd)
		self.Wha=self._random_matrix(self.NA, self.InitWhamean, self.InitWhastd)
		self.Whm=self._random_matrix(self.NM, self.InitWhmmean, self.InitWhmstd)
		
		##Wmm import
		self.Wmm=self._read_matrix("./MMweight2301.txt")
		
		#Wms import
		self.Wms=self._read_matrix("Wcd50_noBias.txt")
	
	def ConnectW_SM(self):
		"""Copy Wmm[i][j] onto the connection from memory neuron i to memory neuron j."""
		for i in range(self.NM):
			for j in range(self.NM):
				conn=nest.FindConnections([self.Mneurons[i]],[self.Mneurons[j]])
				nest.SetStatus(conn,'weight',self.Wmm[i][j])
	
	def ConnectW(self):
		"""Copy Whs (+Wms), Wha, its transpose Wah, and Whm onto the NEST connections.

		For S, A and H neurons all outgoing connections of a source are set in
		one call. The trailing 1.0 in each weight list is for the source's
		connection to the spike detector, which was created last in __init__.
		NOTE(review): this assumes FindConnections returns a source's
		connections in creation order - confirm against the NEST version in use.
		"""
		##Whs and Wms: state neuron i projects to all hidden, then all memory neurons
		for i in range(self.NS):
			conn=nest.FindConnections([self.Sneurons[i]])
			nest.SetStatus(conn,'weight',list(self.Whs[i])+list(self.Wms[i])+[1.0])
		
		##Wha: action neuron i projects to all hidden neurons
		for i in range(self.NA):
			conn=nest.FindConnections([self.Aneurons[i]])
			nest.SetStatus(conn,'weight',list(self.Wha[i])+[1.0])
		
		##Wah: transpose of Wha, hidden neuron i projects to all action neurons
		self.Wah=[list(column)+[1.0] for column in zip(*self.Wha)]
		for i in range(self.NH):
			conn=nest.FindConnections([self.Hneurons[i]])
			nest.SetStatus(conn,'weight',self.Wah[i])
		
		##Whm: memory neurons also project to each other, so address each pair explicitly
		for i in range(self.NM):
			for j in range(self.NH):
				conn=nest.FindConnections([self.Mneurons[i]],[self.Hneurons[j]])
				nest.SetStatus(conn,'weight',self.Whm[i][j])
		
	
	
####################################
def Digit(state=1):
	"""Load one randomly chosen observation image for `state` as a flat list of ints.

	For state 3 the image of the module-level GoalState is shown instead.
	One of ten variants (suffix _1 .. _10) is picked with random.randint and
	all comma-separated fields of the file are returned in reading order.
	Raises IOError/OSError if the file is missing and ValueError if a field is
	not an integer.
	"""
	if state==3:
		state= GoalState
	
	#filename="digit"+str(state)
	filename="../shrunk_digit_easy_test_20_15T/digit"+str(state)+"_"+str(random.randint(1,10))
	
	#this runs several times per step, so close the file explicitly
	with open(filename) as csvfile:
		return [int(elem) for row in csv.reader(csvfile) for elem in row]

####################################
def StateClamp(obs, action, inh=0):
	"""Clamp an observation (and optionally an action) onto the network and simulate.

	obs:    sequence; every entry equal to 1 switches on the drive of the
	        state neuron with the same index.
	action: 1..4 drives the matching action group and pushes the other three
	        groups down; any other value leaves the action drives at zero.
	inh:    1 additionally applies the inhibitory drive to the memory neurons
	        and simulates one extra period of T ms first.
	Always ends with nest.Simulate(T).
	"""
	drive=SNN.drive
	
	# start from a clean slate: every generator at zero amplitude
	nest.SetStatus(drive,[SNN.driveparams_off] )
	
	for idx, pixel in enumerate(obs):
		if pixel==1:
			nest.SetStatus(drive[idx:idx+1],[SNN.driveparams_on] )
	
	if action in (1,2,3,4):
		# generators NS .. NS+3 belong to action groups 1 .. 4
		for slot in range(4):
			if action==slot+1:
				params=SNN.driveparams_on_a
			else:
				params=SNN.driveparams_off_a
			nest.SetStatus(drive[SNN.NS+slot:SNN.NS+slot+1],[params] )
	
	if inh==1:
		nest.SetStatus(drive[SNN.NS+4:SNN.NS+5],[SNN.driveparams_inh] )
		nest.Simulate(T)
	
	nest.Simulate(T)





####################################
def CalcFR():
	"""Count spikes per neuron over the last 100 ms of the current T-ms window.

	Returns (hiddenFR, actionFR):
	  hiddenFR - spike counts of the NH hidden neurons (GIDs NS+1 .. NS+NH)
	  actionFR - list of four summed counts, one per action group; the groups
	             split the action population at NA/4, NA*2/4 and NA*3/4
	When there are no spikes at all, or none in the last 100 ms, all counts
	are zero (taken from a numpy.zeros array, as before).

	The window end is derived from the time of the last recorded spike, and
	the spike record is assumed to be ordered by time.
	"""
	events=nest.GetStatus(SNN.sd,'events')[0]
	spikesender=events['senders']#array
	spiketimes=events['times']#array
	
	#index 0 is unused because GIDs start at 1; memory neurons are not counted here
	nfr=SNN.NS+SNN.NH+SNN.NA+1
	fr=numpy.zeros(nfr)
	
	if len(spiketimes)!=0:
		maxTime=(spiketimes[-1]//T +1)*T
		recent=numpy.nonzero(spiketimes>maxTime-100)[0]# last 100 ms
		if len(recent)!=0:
			spcount=numpy.asarray(spikesender[recent[0]:], dtype=int)
			#one pass over the spikes instead of sort + count per GID;
			#tolist() keeps the counts as plain Python ints
			fr=numpy.bincount(spcount, minlength=nfr)[:nfr].tolist()
	
	hiddenFR=fr[SNN.NS+1:SNN.NS+SNN.NH+1]
	
	#floor division keeps the group bounds integral on every Python version
	base=SNN.NS+SNN.NH+1
	edges=[0, SNN.NA//4, SNN.NA*2//4, SNN.NA*3//4, SNN.NA]
	actionFR=[sum(fr[base+edges[k]:base+edges[k+1]]) for k in range(4)]
	return hiddenFR, actionFR


####################################
def Actionselection(actionFR,episode,state):
	"""Pick the next action from the action-group firing counts.

	Outside state 3 the action is always 1. In state 3 the choice is between
	action 2 and action 4 by a two-way softmax over actionFR[1] and
	actionFR[3]; the inverse temperature grows linearly with `episode`
	(module-level Beta and Nepisode) from an offset of 0.5/6.
	"""
	if state!=3:
		return 1
	
	beta=Beta*float(episode)/float(Nepisode)+0.5/6.
	draw=random.random()
	weight2=math.exp(beta*actionFR[1])
	weight4=math.exp(beta*actionFR[3])
	if draw<(weight2)/(weight2+weight4):
		return 2
	return 4

####################################
def StateTrans(state,action,goalflag, length):
	"""Apply `action` in `state` and return (nextState, reward, goal, goalflag, length-1).

	reward is -500. unless the episode ends correctly, in which case it is
	20000. and goalflag becomes 1. goal is 1 whenever the episode ends
	(action 2 or 4 in state 3), rewarded or not. Action 2 is correct when the
	module-level InitState equals GoalState, action 4 when they differ.

	`length` is only decremented: in state 2, action 1 always leads to
	state 3, whatever the value of length.
	State/action pairs not listed below leave nextState unassigned, as before.
	"""
	reward=-500.
	goal=0
	
	if state==0 or state==1:
		# both start states: action 1 moves on, anything else stays put
		if action==1:
			nextState=2
		elif state==0:
			nextState=0
		else:
			nextState=1
	elif state==2:
		if action==1:
			nextState=3
		elif action==3:
			# go back to the state this episode started in
			if InitState==0:
				nextState=0
			if InitState==1:
				nextState=1
		else:
			nextState=2
	elif state==3:
		if action==1:
			nextState=3
		elif action==2:
			goal=1
			nextState=4
			if InitState==GoalState:
				reward=20000.
				goalflag=1
		elif action==4:
			goal=1
			nextState=4
			if InitState!=GoalState:
				reward=20000.
				goalflag=1
		elif action==3:
			nextState=2
	
	return nextState, reward, goal, goalflag, length-1
	
####################################
def CalcFE_AVE(binSP, state, action, whs, wha, whm):
	"""Estimate the free energy from binned spikes, averaged over the bins.

	binSP: one row of 0/1 values per neuron (as built by BinSPcnt); rows
	       [0:NS] are read as state, [NS:NS+NH] as hidden, [NS+NH:NS+NH+NA]
	       as action and the following NM rows as memory neurons.
	whs, wha, whm: weight matrices; for the products below to be conformable
	       they must have NS, NA and NM rows respectively and NH columns.
	state, action: only used to build the binary arrays of the Flagsa==1 path.

	With the module-level Flagsa != 1 this returns
	(s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy), where
	each *_hat entry is sum(row)/50 + 0.001, entropy is the binary entropy of
	h_hat summed over hidden neurons, expEnergy_mean is the per-bin energy
	-(s'Whs h + a'Wha h + m'Whm h) summed over bins and divided by 50, and
	freeEnergy = -entropy + expEnergy_mean.

	NOTE(review): with Flagsa == 1 only six values are returned, while the
	main loop unpacks seven - that path looks stale; confirm before enabling.
	"""
	#h_hat
	maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
	#this can be considered as mean over bins if maxFR is 50
	bins=50
	s_hat=[]
	a_hat=[]
	h_hat=[]
	m_hat=[]
	# normalised activity of every neuron, offset by 0.001 so log() below never sees 0
	for i in range(0,SNN.NS):
		s_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
	for i in range(SNN.NS,SNN.NS+SNN.NH):
		h_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
	for i in range(SNN.NS+SNN.NH,SNN.NS+SNN.NH+SNN.NA):
		a_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
	for i in range(SNN.NS+SNN.NH+SNN.NA,SNN.NS+SNN.NH+SNN.NA+SNN.NM):
		m_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error


	#calc entropy
	# NOTE(review): a hidden row with a spike in all 50 bins gives x = 1.001, and math.log(1-x) then raises ValueError
	entropy=0
	for x in h_hat:
		entropy=entropy-x*math.log(x)-(1-x)*math.log(1-x)
		
	#calc expected energy
	##create state array
	# binary state code: a block of 16 ones whose position depends on `state` (only used when Flagsa==1)
	sarray=[0]*SNN.NS
	if state<3:
		sarray[state*16:state*16+16]=[1]*16
	else:
		sarray[(state-1)*16:(state-1)*16+16]=[1]*16
	sarraya=numpy.array(sarray)
	
	##create action array
	# NOTE(review): the slice bound uses "/" (integer division only under Python 2) and a hard-coded width of 50;
	# when the slice runs past the end of the list the assignment changes the list length - verify against NA
	aarray=[0]*SNN.NA
	aarray[(action+1)/2*50:(action+1)/2*50+50]=[1]*50
	aarraya=numpy.array(aarray)
	#create matrices
	# one matrix per population, rows = neurons, columns = bins
	smat=scipy.mat(binSP[0:SNN.NS][:])
	amat=scipy.mat(binSP[SNN.NS+SNN.NH:SNN.NS+SNN.NH+SNN.NA][:])
	hmat=scipy.mat(binSP[SNN.NS:SNN.NS+SNN.NH][:])
	mmat=scipy.mat(binSP[SNN.NS+SNN.NH+SNN.NA:SNN.NS+SNN.NH+SNN.NA+SNN.NM][:])
	
	whsmat=scipy.mat(whs)
	whamat=scipy.mat(wha)
	whmmat=scipy.mat(whm)
	#ExpEnergy
	if Flagsa==1:
		# binary state/action codes instead of recorded spikes; the memory term is not included here
		expEnergy_s=-sarraya*whsmat*hmat
		expEnergy_a=-aarraya*whamat*hmat
		temp2=expEnergy_a+expEnergy_s
		
	else:
		# (bins x bins) products; the diagonal pairs each bin's input spikes with the same bin's hidden spikes
		expEnergy_s=-smat.T*whsmat*hmat
		expEnergy_a=-amat.T*whamat*hmat
		expEnergy_m=-mmat.T*whmmat*hmat
		temp=expEnergy_a+expEnergy_s+expEnergy_m
		temp2=temp.diagonal()	
	expEnergy=temp2.tolist()[0]## convert array back to Python list
	expEnergy_mean=sum(expEnergy)/bins
	#FreeEnergy
	freeEnergy=-entropy+expEnergy_mean
	
	if Flagsa==1:
		return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
	else:
		return s_hat, a_hat, h_hat, m_hat, entropy, expEnergy_mean, freeEnergy
####################################
def CalcFE_LPF(binSP, state, action, whs, wha, whm):
	"""Estimate the free energy bin by bin and smooth it with a low-pass filter.

	binSP: one row of 0/1 values per neuron with at least 50 entries per row;
	       row ranges are read as in CalcFE_AVE (state, hidden, action, memory).
	whs, wha, whm: weight matrices; they must have NS, NA and NM rows and NH
	       columns for the matrix products below to be conformable.
	state, action: accepted for signature compatibility; not read in this function.

	Returns (s_hat, a_hat, h_hat, m_hat, entropy[-1], expEnergy[-1],
	freeEnergy_t[-1], freeEnergy_t): s_hat/a_hat/m_hat are sum(row)/50 + 0.001,
	h_hat is the last value of each hidden neuron's filtered trace, and
	freeEnergy_t is the filtered free-energy trace over the 50 bins.
	"""
	alpha_h=0.1
	alpha_f=0.1
	s_hat=[]
	a_hat=[]
	h_hat=[]
	m_hat=[]
	maxFR=50.#47 spikes per 1 neuron for 100 ms. 50 is for safety
	bins=50

	for i in range(0,SNN.NS):
		s_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
	for i in range(SNN.NS+SNN.NH,SNN.NS+SNN.NH+SNN.NA):
		a_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
	for i in range(SNN.NS+SNN.NH+SNN.NA,SNN.NS+SNN.NH+SNN.NA+SNN.NM):
		m_hat.append(sum(binSP[i])/maxFR+0.001)#0.001 is to avoid error
			
	# h_trace[i][j]: exponentially filtered activity of hidden neuron i at bin j,
	# started at 0.5 and updated as (1-alpha_h)*previous + alpha_h*spike
	h_trace=[[]]
	bin_h=binSP[SNN.NS:SNN.NS+SNN.NH]
	#binSP_T=map(list, zip(*binSP))#transposition
	for i in range(len(bin_h)):
		for j in range(bins):
			if j==0:
				h_trace[i].append(0.5)
			else:
				h_trace[i].append((1-alpha_h)*h_trace[i][j-1]+alpha_h*bin_h[i][j])
		h_hat.append(h_trace[i][-1])
		h_trace.append([])
	h_trace.pop()
	
	
	#calc entropy
	# entropy[i]: for bin i, sum over hidden neurons j of -b*log(h) - (1-b)*log(1-h),
	# with b the binary spike and h the filtered trace of neuron j at that bin
	entropy=[[]]
	for i in range(len(h_trace[0])):
		for j in range(len(h_trace)):
			if j==0:
				entropy[i]=-bin_h[j][i]*math.log(h_trace[j][i])-(1.-bin_h[j][i])*math.log(1.-h_trace[j][i])
			else:
				entropy[i]=entropy[i]-bin_h[j][i]*math.log(h_trace[j][i])-(1.-bin_h[j][i])*math.log(1.-h_trace[j][i])
		entropy.append([])
	entropy.pop()
	
	#calc expected energy
	#create matrices
	# one matrix per population, rows = neurons, columns = bins
	smat=scipy.mat(binSP[0:SNN.NS][:])
	amat=scipy.mat(binSP[SNN.NS+SNN.NH:SNN.NS+SNN.NH+SNN.NA][:])
	hmat=scipy.mat(binSP[SNN.NS:SNN.NS+SNN.NH][:])
	mmat=scipy.mat(binSP[SNN.NS+SNN.NH+SNN.NA:SNN.NS+SNN.NH+SNN.NA+SNN.NM][:])
	whsmat=scipy.mat(whs)
	whamat=scipy.mat(wha)
	whmmat=scipy.mat(whm)
	#ExpEnergy
	# (bins x bins) products; the diagonal holds the energy of each bin
	expEnergy_s=-smat.T*whsmat*hmat
	expEnergy_a=-amat.T*whamat*hmat
	expEnergy_m=-mmat.T*whmmat*hmat
	temp=expEnergy_a+expEnergy_s+expEnergy_m
	temp2=temp.diagonal()
	expEnergy=temp2.tolist()[0]## convert array back to Python list
	#f
	# instantaneous free energy of each bin
	f=[]
	for i in range(bins):
		f.append(-entropy[i]+expEnergy[i])
	
	#FreeEnergy
	# first-order low-pass filter over f, seeded with f[0]
	for i in range(bins):
		if i==0:
			freeEnergy_t=[f[0]]
		else:
			freeEnergy_t.append(freeEnergy_t[i-1]+alpha_f*(f[i]-freeEnergy_t[i-1]))
	
	#return sarray, aarray, h_hat, entropy, expEnergy_mean, freeEnergy
	return s_hat, a_hat, h_hat,m_hat, entropy[-1], expEnergy[-1], freeEnergy_t[-1],freeEnergy_t

####################################
def UpdateW(freeEnergy_1,freeEnergy, reward_1, reward, goal, sarray_1, aarray_1, marray_1, h_hat, h_hat_1):
	"""Compute the (unscaled) weight changes for Whs, Wha and Whm.

	Each result is the outer product of an input activity vector with a hidden
	activity vector, multiplied by a scalar error term:
	  goal == 0: error = reward_1 - Gamma*freeEnergy + freeEnergy_1, hidden = h_hat_1
	  otherwise: error = reward + freeEnergy,                        hidden = h_hat
	Gamma is the module-level discount factor.

	Returns (deltaWhs, deltaWha, deltaWhm) as lists of lists with one row per
	entry of sarray_1 / aarray_1 / marray_1 and one column per hidden entry.
	"""
	def scaled_outer(scale, pre, post):
		# (scale*p)*x, evaluated left to right for every pair
		return [[scale*p*x for x in post] for p in pre]
	
	if goal==0:
		error=reward_1-Gamma*freeEnergy+freeEnergy_1
		hidden=h_hat_1#Notice! h_hat was added 0.001 in CalcFE
	else:
		error=reward+freeEnergy
		hidden=h_hat#Notice! h_hat was added 0.001 in CalcFE
	
	deltaWhs=scaled_outer(error, sarray_1, hidden)
	deltaWha=scaled_outer(error, aarray_1, hidden)
	deltaWhm=scaled_outer(error, marray_1, hidden)
	return deltaWhs, deltaWha, deltaWhm


####################################
def WinputSNN(deltaWhs, deltaWha, deltaWhm):
	"""Add Alpha times each delta to SNN.Whs / SNN.Wha / SNN.Whm and push the result into NEST.

	The three new matrices are computed first and stored back as fresh nested
	lists (via tolist()), then SNN.ConnectW() writes them onto the
	connections. Alpha is the module-level learning rate.
	"""
	pairs=((SNN.Whs, deltaWhs), (SNN.Wha, deltaWha), (SNN.Whm, deltaWhm))
	
	updated=[]
	for current, delta in pairs:
		newmat=scipy.mat(current)+Alpha*scipy.mat(delta)
		# tolist() yields new row lists, so rows extended inside ConnectW never alias old ones
		updated.append(newmat.tolist())
	
	SNN.Whs, SNN.Wha, SNN.Whm = updated
	
	SNN.ConnectW()

####################################
def Outputsd():
	"""Dump the spike detector's sender GIDs and spike times to two text files."""
	events=nest.GetStatus(SNN.sd,'events')[0]
	for key, target in (('senders', "Spikesender.txt"), ('times', "Spiketimes.txt")):
		events[key].tofile(target, sep=', ', format = "%e")

####################################
def BinSPcnt():
	"""Turn the most recent spikes into a binary neuron-by-bin matrix.

	The window covers 50 bins of 2 ms ending at maxTime, the end of the
	2*T-ms period that contains the last recorded spike; bin i starts just
	after maxTime-101+2*i. Entry [n][i] is 1 if the neuron with GID n+1 has
	at least one spike in bin i, else 0.

	Returns a list of NS+NH+NA+NM independent rows of 50 ints. When nothing
	was recorded, or no spike falls into the window, all entries are 0.
	If the spikes stop before the window ends, the remaining spikes are all
	assigned to the last bin (behaviour kept from the original code).
	The spike record is assumed to be ordered by time.
	"""
	events=nest.GetStatus(SNN.sd,'events')[0]
	spikesender=events['senders']#array
	spiketimes=events['times']#array
	bins=50
	binsize=2
	nneurons=SNN.NS+SNN.NH+SNN.NA+SNN.NM
	
	#fresh rows every time; [[0]*bins]*n would make every row the same list object
	binSP=[[0]*bins for _ in range(nneurons)]
	
	#same guard as CalcFR: an empty record has no last spike to index
	if len(spiketimes)==0:
		return binSP
	
	maxTime=(spiketimes[-1]//(2*T) +1)*2*T
	lastspike=spiketimes.max()
	
	#binidx[i]: index of the first spike later than the start of bin i
	binidx=[]
	for i in range(bins):
		start=maxTime-101+i*binsize
		if lastspike>start:
			#nonzero is like "find" in matlab, [0][0] takes the first index
			binidx.append(numpy.nonzero(spiketimes>start)[0][0])
		elif binidx:
			#no spike after this bin start: reuse the last valid index
			binidx.append(binidx[-1])
	
	if not binidx:
		#no spike inside the window at all
		return binSP
	
	#bin j holds spikes edges[j] .. edges[j+1]-1; the last bin runs to the end of the record
	edges=binidx+[len(spikesender)]
	for j in range(bins):
		for gid in spikesender[edges[j]:edges[j+1]]:
			row=int(gid)-1
			if 0<=row<nneurons:
				binSP[row][j]=1
	
	return binSP
	
	
####################################
def _read_weight_csv(filename):
	"""Parse a comma-separated text file into a list of rows of floats; the file is closed on return."""
	with open(filename) as csvfile:
		return [[float(elem) for elem in row] for row in csv.reader(csvfile)]


def ReadW():
	"""Load Wha, Whs and Whm from Wha.txt, Whs.txt and Whm.txt and install them in the network.

	All three files are read before SNN is touched, then SNN.ConnectW()
	copies the matrices onto the NEST connections.
	Raises IOError/OSError if a file is missing and ValueError if a field is
	not a number.
	"""
	wha=_read_weight_csv("Wha.txt")
	whs=_read_weight_csv("Whs.txt")
	whm=_read_weight_csv("Whm.txt")
	
	SNN.Wha=wha
	SNN.Whs=whs
	SNN.Whm=whm
	
	SNN.ConnectW()
	
####################################
####################################
#main
# Build the network once; every function in this file reaches it through the global SNN.
SNN=SNNc()
# Free-energy estimator: 1 selects CalcFE_LPF, anything else CalcFE_AVE.
Flag_FE=0#1 is LPF
Flagsa=0#1 for using binary state and action in FEAVE
# Duration passed to every nest.Simulate() call in StateClamp.
T=500
Nepoch=1
Nepisode=3000
# Upper bound on the number of steps per episode.
Maxstep=30
# Discount factor used in UpdateW and in the cumulative-reward bookkeeping.
Gamma=0.99
# Learning rate applied to the weight deltas in WinputSNN.
Alpha=0.0001
# Slope of the episode-dependent inverse temperature in Actionselection.
Beta=0.1/6.
# Per-epoch histories; each gets one list per epoch appended at the end of the epoch loop.
HistoryReward=[]
HistoryNstep=[]
HistoryCumR=[]
Historys=[]
Historya=[]
HistoryFE=[]
# Filtered free-energy traces, only filled when Flag_FE==1.
HistoryFEts=[]
HistoryFEtg=[]
HistoryFR=[]
	
# Main training loop: Nepoch epochs of Nepisode episodes, each at most Maxstep steps.
for Epoch in range(Nepoch):
	# Set the kernel time back to 0 at the start of every epoch.
	nest.SetKernelStatus({'time':0.0})
	Reward_info=[]
	Nstep=[]
	HistoryCumRtemp=[]
	Historystemp2=[]
	Historyatemp2=[]
	HistoryFEtemp2=[]
	HistoryFRtemp2=[]
	
	# Draw random Whs/Wha/Whm, load Wmm/Wms from file and copy everything onto the connections.
	SNN.InitW()
	SNN.ConnectW_SM()
	SNN.ConnectW()
	
	# Then replace Whs/Wha/Whm with the contents of Whs.txt/Wha.txt/Whm.txt.
	# NOTE(review): those files are written at the end of this script, so they must already
	# exist from an earlier run - confirm that resuming from saved weights is intended here.
	ReadW()
	
	for Episode in range(Nepisode):
		# Per-episode state. InitState alternates 0/1 with the episode number; GoalState is random.
		# StateTrans rewards action 2 when InitState==GoalState and action 4 when they differ.
		GoalFlag=0
		Goal=0
		Action=0
		InitState=Episode%2
		GoalState=random.randint(0,1)
		Length=random.randint(0,4)
		State=InitState
		State_1=InitState
		Action_1=1
		FreeEnergy=0.
		FreeEnergy_1=0.
		Reward=0.
		Reward_1=0.
		CumR=0
		Historystemp=[]
		Historyatemp=[]
		HistoryFEtemp=[]
		HistoryFRtemp=[]
		
		# Episode start: show the first observation with the memory neurons inhibited (simulates 2*T).
		Obs=Digit(State)
		StateClamp(Obs, 0,1) #run SNN
		
		for Step in range(Maxstep):
			Obs_1=Obs[:]
			# A new image variant of the current state is drawn on every step.
			Obs=Digit(State)
			# Phase 1 (T ms): observation only, action drives at zero; the firing counts decide the action.
			StateClamp(Obs, 0) #run SNN
			[HiddenFR, ActionFR]=CalcFR()
			Action= Actionselection(ActionFR,Episode,State)
			# Under Python 2 this prints a tuple, not a formatted line.
			print("State", State, "Action", Action)
			# Phase 2 (T ms): observation plus the chosen action clamped; these spikes feed the free energy.
			StateClamp(Obs, Action)# for FE
			#[HiddenFR, ActionFR]=CalcFR()
			BinSP=BinSPcnt()
			if Step==0:
				# First step: "previous" quantities are initialised with the current ones.
				State_1=State
				Action_1=Action
				#HiddenFR_1=HiddenFR
				# NOTE: x[:][:] copies only the outer list; the row lists stay shared.
				BinSP_1=BinSP[:][:]
				Whs_1=SNN.Whs[:][:]
				Wha_1=SNN.Wha[:][:]
				Whm_1=SNN.Whm[:][:]
			# Free energy of the current and of the previous step, both evaluated with the
			# weights saved from the previous step (Whs_1, Wha_1, Whm_1).
			if Flag_FE==1:
				[Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy, FreeEnergy_t]=CalcFE_LPF(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
				[Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1, FreeEnergy_t_1]=CalcFE_LPF(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
			else:
				[Sarray, Aarray, H_hat, Marray, Entropy, ExpEnergy, FreeEnergy]=CalcFE_AVE(BinSP, State, Action, Whs_1, Wha_1, Whm_1)#calc FE
				[Sarray_1, Aarray_1, H_hat_1, Marray_1, Entropy_1,ExpEnergy_1, FreeEnergy_1]=CalcFE_AVE(BinSP_1, State_1, Action_1, Whs_1, Wha_1, Whm_1)#calc FE
			# Shift current -> previous for the next step.
			State_1=State###
			Action_1=Action###
			Reward_1=Reward	###
			#HiddenFR_1=HiddenFR###
			BinSP_1=BinSP[:][:]
			Whs_1=SNN.Whs[:][:]	###
			Wha_1=SNN.Wha[:][:]	###
			Whm_1=SNN.Whm[:][:]	###
			# Environment transition, then the non-terminal update (goal argument fixed to 0).
			[State, Reward, Goal, GoalFlag, Length]=StateTrans(State, Action, GoalFlag, Length)# move
			[DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 0, Sarray_1,  Aarray_1,  Marray_1, H_hat, H_hat_1)
			Print1="Epoch: %d, Episode: %d, Step: %d, Flag: %d"
			Print2="New state: %d, State: %d, Action: %d, Reward: %d, Goal: %d, CumR=%.3f"
			Print3="FreeEnergy(s=%d, a=%d)=%.3f, Entropy=%.3f, ExpEnergy=%.3f"
			Print4="ActionFR1=%d, ActionFR2=%d, ActionFR3=%d, ActionFR4=%d"
			print Print1 % (Epoch+1, Episode+1, Step+1, GoalFlag)
			print Print2 % (State, State_1, Action_1, Reward, Goal, CumR)
			print Print3 % (State_1, Action_1, FreeEnergy, Entropy, ExpEnergy)
			print Print4 % (ActionFR[0], ActionFR[1], ActionFR[2], ActionFR[3])
			print("Reward",Reward_info)
			if Step==0:
				Historystemp.append(State_1)
			Historystemp.append(State)
			Historyatemp.append(Action)
			HistoryFEtemp.append(FreeEnergy)
			HistoryFRtemp.append(ActionFR)
			if Step==0 and Epoch ==0 and Flag_FE==1:
				HistoryFEts.append(FreeEnergy_t)
			if Step != 0:
				# Apply the non-terminal update from the second step on.
				WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#not run 1st move
				CumR=CumR+Reward_1*Gamma**(Step-1)
			if Goal == 1:
				# Episode ended (action 2 or 4 in state 3): terminal update, then leave the step loop.
				CumR=CumR+Reward*Gamma**(Step+1)
				Whs=SNN.Whs[:][:]	####not need if ResetNetwork is off
				Wha=SNN.Wha[:][:]	####not need if ResetNetwork is off
				Whm=SNN.Whm[:][:]	####not need if ResetNetwork is off
				# Reset the network state between episodes, except after the very last episode.
				if Epoch != Nepoch-1 or Episode != Nepisode-1:
					nest.ResetNetwork()#############
				# NOTE(review): SNN.Whs/Wha/Whm are swapped to the previous-step copies and restored
				# three lines later, but UpdateW does not read them - the swap appears to have no effect.
				SNN.Whs=Whs_1[:][:]	####not need if ResetNetwork is off
				SNN.Wha=Wha_1[:][:]	####not need if ResetNetwork is off
				SNN.Whm=Whm_1[:][:]	####not need if ResetNetwork is off
				[DeltaWhs, DeltaWha, DeltaWhm]=UpdateW(FreeEnergy_1,FreeEnergy, Reward_1, Reward, 1, Sarray, Aarray, Marray, H_hat, H_hat_1)
				SNN.Whs=Whs[:][:]	####not need if ResetNetwork is off
				SNN.Wha=Wha[:][:]	####not need if ResetNetwork is off
				SNN.Whm=Whm[:][:]	####not need if ResetNetwork is off
				WinputSNN(DeltaWhs, DeltaWha, DeltaWhm)#not run 1st move
				break
			
		# Episode bookkeeping. Step keeps the index of the last executed step.
		Historystemp2.append(Historystemp)
		Historyatemp2.append(Historyatemp)
		HistoryFEtemp2.append(HistoryFEtemp)
		HistoryFRtemp2.append(HistoryFRtemp)
		Reward_info.append(GoalFlag)
		Nstep.append(Step+1)
		HistoryCumRtemp.append(CumR)
		if Epoch ==0 and Flag_FE==1:
			# Episodes that ran out of steps get a row of 50 zeros instead of a trace.
			if Goal==0:
				HistoryFEtg.append([0 for _ in range(50)])
			else:
				HistoryFEtg.append(FreeEnergy_t)


	Historys.append(Historystemp2)
	Historya.append(Historyatemp2)
	HistoryFE.append(HistoryFEtemp2)
	HistoryFR.append(HistoryFRtemp2)
	HistoryNstep.append(Nstep)
	HistoryReward.append(Reward_info)
	HistoryCumR.append(HistoryCumRtemp)


#to file
# History.txt: per episode a header line, then states, actions, free energies and action firing counts
with open('History.txt', 'w') as fHistory:
	for i in range(Nepoch):
		for j in range(Nepisode):
			fHistory.write("Epoch:%d, Episode: %d\n" % (i+1, j+1))
			for record in (Historys, Historya, HistoryFE, HistoryFR):
				fHistory.write(str(record[i][j]) )
				fHistory.write('\n')
		fHistory.write('\n')

# one line per epoch, one comma-separated value per episode
for path, table in (('HistoryReward.txt', HistoryReward), ('HistoryNstep.txt', HistoryNstep), ('HistoryCumR.txt', HistoryCumR)):
	with open(path, 'w') as fTable:
		for i in range(Nepoch):
			fTable.write(', '.join([str(table[i][j]) for j in range(Nepisode)]))
			fTable.write('\n')

# final weight matrices, one row per source neuron and NH comma-separated columns
# (these are the files ReadW loads)
for path, weights, nrows in (('Whs.txt', SNN.Whs, SNN.NS), ('Wha.txt', SNN.Wha, SNN.NA), ('Whm.txt', SNN.Whm, SNN.NM)):
	with open(path, 'w') as fW:
		for i in range(nrows):
			fW.write(', '.join([str(weights[i][j]) for j in range(SNN.NH)]))
			fW.write('\n')

# filtered free-energy traces (50 bins per episode), only recorded with the LPF estimator
if Flag_FE==1:
	for path, traces in (('FEts.txt', HistoryFEts), ('FEtg.txt', HistoryFEtg)):
		with open(path, 'w') as fFE:
			for i in range(Nepisode):
				fFE.write(', '.join([str(traces[i][j]) for j in range(50)]))
				fFE.write('\n')



# Final figure: raster plot of everything the spike detector recorded, saved as raster.eps.
# The commented-out lines below are leftover manual experiments (voltmeter recording and
# single StateClamp calls) and are not executed.
#nest.SetKernelStatus({'time':0.0})
#nest.SetStatus(SNN.voltmeter,[{"to_file": True, "withtime": True}])
#StateClamp(0)
#StateClamp(0,-1)
#StateClamp(0,0,0,-1)
#StateClamp(0,1,0,-1)
#StateClamp(1,0,0,1)
#StateClamp(1,1,0,1)
nest.raster_plot.from_device(SNN.sd, hist=False)
#plt.xlim( (0, 1000) )
##plt.show()
plt.savefig('raster.eps')
#plt.close()
##nest.voltage_trace.from_device(SNN.voltmeter)