function [Starts, Goals, Decoded, Vec_l, Error, LLA_Time] = LinearLookAheadModel(GC_cpm, DrawFig, iterations, Range)
%% Linear look ahead model of vector navigation with grid cells
% Daniel Bush, UCL Institute of Cognitive Neuroscience
% Reference: Using Grid Cells for Navigation (2015) Neuron (in press)
% Contact: drdanielbush@gmail.com
%
% For each iteration, a random start and goal location are drawn in a square
% environment. Each axis is handled independently. Grid cell phases are
% advanced away from the start location in both directions at a constant
% sweep speed. At every timestep, Poisson spikes drawn from the grid cells
% are projected onto a line of place cells and thresholded by a
% winner-take-all (WTA) rule. The sweep stops when the place cell nearest
% the goal is active. The distance swept at that moment is the decoded
% displacement along that axis.
%
% Inputs:
%   GC_cpm     = Number of grid cells per unique phase, in each module
%   DrawFig    = Plot figure of errors (0 / 1); optional, default 0
%   iterations = Number of random start / goal pairs; optional, default 10
%   Range      = Size of environment along each axis (m); optional,
%                default 500
%
% Outputs:
%   Starts   = Random 2D start locations (m)
%   Goals    = Random 2D goal locations (m)
%   Decoded  = 2D translation vector decoded from grid cell activity (m);
%              NaN on any axis where the goal place cell never became active
%   Vec_l    = Length of the TRUE 2D translation vector Goals - Starts (m)
%   Error    = Euclidean distance between the true and decoded translation
%              vectors (m); NaN if either axis failed to decode
%   LLA_Time = Time taken to complete linear look-ahead on each axis (s)

% Fill in optional arguments (the original two-argument call is unchanged)
if nargin < 2 || isempty(DrawFig)
    DrawFig = 0;
end
if nargin < 3 || isempty(iterations)
    iterations = 10;                % How many iterations to run?
end
if nargin < 4 || isempty(Range)
    Range = 500;                    % Size of environment (m)
end

% Provide some parameters for the simulation
GC_mps    = 20;                     % Unique grid cell phases on each axis, per module
GC_scales = 0.25.*1.4.^(0:9);       % Grid cell scales (m)
GC_r      = 30;                     % Peak grid cell firing rate (Hz)
Emax_k    = 0.01;                   % Place cell WTA parameter
dt        = 0.005;                  % Place cell synaptic integration timestep (s)
SWR_speed = 8;                      % Speed of linear look ahead sweep (m/s)

% Compute the linear look ahead spatial resolution and grid module sizes
Dist_step = SWR_speed * dt;         % Displacement increment for place cells (m)
Distances = 0 : Dist_step : Range;  % Assign place cell distance coding (m)
N_place   = length(Distances);      % Total number of place cells
N_scales  = length(GC_scales);      % Number of grid modules
N_grid    = N_scales*GC_mps;        % Total number of grid cell phase offsets

% Generate synaptic weight matrices: row (scale-1)*GC_mps + offset holds the
% tuning curve of that grid cell evaluated at each place cell's distance
Grid_Place_w = zeros(N_grid, N_place);
for scale = 1 : N_scales
    for offset = 1 : GC_mps
        Grid_Place_w((scale-1)*GC_mps + offset, 1 : N_place) = ...
            ((cos((mod(Distances-((offset-1)/GC_mps)*GC_scales(scale),GC_scales(scale))/GC_scales(scale))*2*pi)+1)/2);
    end
end

% Assign some memory
Starts   = nan(iterations,2);       % Log of start positions
Goals    = nan(iterations,2);       % Log of goal positions
Decoded  = nan(iterations,2);       % Log of decoded displacement on each axis
Error    = nan(iterations,1);       % Log of distance error for each computed vector
Vec_l    = nan(iterations,1);       % Log of true vector lengths
LLA_Time = nan(iterations,2);       % Log of time taken for each linear look ahead event

% Loop-invariant sweep quantities
dirs       = [-1 1];                                            % Sweep directions along an axis
Phase_step = repmat(((SWR_speed*dt)./GC_scales*2*pi)',1,GC_mps);  % Phase advance per timestep (rad)

% Then, for each iteration...
for i = 1 : iterations

    % Update the user roughly every 10% (works for any iteration count)
    if floor(10*i/iterations) > floor(10*(i-1)/iterations)
        disp([int2str(i/iterations*100) '% complete...']); drawnow
    end

    % Randomly assign start and goal locations
    Starts(i,:) = [Range*rand Range*rand];
    Goals(i,:)  = [Range*rand Range*rand];

    % For each axis...
    for ax = 1 : 2

        % Single place cell nearest to the goal on this axis; min returns
        % one index even when two cells are equidistant from the goal
        [~, Goal_ind] = min(abs(Distances - Goals(i,ax)));

        % Compute the phase of each grid cell at the starting location
        % (rows = modules, columns = phase offsets), one copy per direction
        Phase = (repmat(((mod(Starts(i,ax),GC_scales)./GC_scales)*GC_mps)',1,GC_mps) ...
                 - (meshgrid(1:GC_mps,1:N_scales)-1))/GC_mps*2*pi;
        Phase = repmat(Phase,[1 1 2]);

        % Then run the dynamics
        finished = [0 0];
        found    = false;
        t        = 1;
        while ~found && ~all(finished)

            % For each direction along the axis...
            for dir = 1 : length(dirs)

                % Skip directions whose sweep has already hit the boundary
                if finished(dir)
                    continue
                end

                % Grid cell firing rates -> Poisson spike counts summed over
                % the GC_cpm cells sharing each phase -> place cell input
                Rates = reshape((1+cos(Phase(:,:,dir)'))/2, N_grid, 1);
                Rates = sum(poissrnd(repmat(Rates*GC_r*dt,[1 1 GC_cpm])),3);
                Rates = Rates' * Grid_Place_w;
                Rates = Rates .* (Rates >= (1-Emax_k)*max(Rates));   % WTA

                if Rates(Goal_ind) > 0
                    % Goal reached: record displacement swept so far and stop,
                    % so the other direction cannot overwrite this result
                    Decoded(i,ax)  = dirs(dir)*(t-1)*dt*SWR_speed;
                    LLA_Time(i,ax) = t*dt;
                    found = true;
                    break
                elseif Rates(1) > 0 || Rates(end) > 0
                    % Sweep has reached the edge of the place cell axis
                    finished(dir) = 1;
                end

                % Increment the phase of grid cell firing
                Phase(:,:,dir) = Phase(:,:,dir) + dirs(dir)*Phase_step;
            end

            % Both directions have now swept the full range; checked after the
            % direction loop so that each direction gets its final step
            if t >= N_place
                finished(:) = 1;
            end
            t = t + 1;
        end
    end

    % Compute the true vector length and error
    Error(i,1) = sqrt(sum(((Goals(i,:) - Starts(i,:)) - Decoded(i,:)).^2,2));
    Vec_l(i,1) = sqrt(sum((Goals(i,:) - Starts(i,:)).^2,2));
end

% Plot vector length v error data, if required
if DrawFig
    figure

    subplot(2,2,1)
    Err_cm = ceil(max([Error ; 0])*100);    % NaN-safe upper histogram edge (cm)
    if Err_cm == 0
        Err_cm = 1;                         % avoid degenerate all-zero bin edges
    end
    temp = histc(Error,linspace(0,Err_cm/100,100)) ./ iterations;
    bar(linspace(0,Err_cm,100),temp,'FaceColor','k','EdgeColor','k')
    set(gca,'FontSize',14)
    xlabel('Error in Decoded Translation Vector (cm)','FontSize',14)
    ylabel('Relative Frequency','FontSize',14)
    axis square

    subplot(2,2,2)
    scatter(Vec_l,Error*100,'k.')
    set(gca,'FontSize',14)
    xlabel('True Translation Vector Length (m)','FontSize',14)
    ylabel('Error in Decoded Translation Vector (cm)','FontSize',14)
    if nnz(~isnan(Error)) >= 2              % regression needs at least two points
        hold on
        b2 = regress(Error*100,[Vec_l ones(size(Vec_l,1),1)]);
        plot(linspace(0,max(Vec_l),10),b2(2) + b2(1).*linspace(0,max(Vec_l),10),'r','LineWidth',3)
        hold off
    end
    axis square

    subplot(2,2,3)
    scatter([abs(Goals(:,1)-Starts(:,1)) ; abs(Goals(:,2)-Starts(:,2))],[LLA_Time(:,1) ; LLA_Time(:,2)],'k.')
    set(gca,'FontSize',14)
    xlabel('True Displacement Along Each Axis (m)','FontSize',14)
    ylabel('Time Taken to Decode Vector (s)','FontSize',14)
    axis square
end