An oscillatory neural autoencoder based on frequency modulation and multiplexing (Soman et al 2018)

 Download zip file 
Help downloading and running models
Accession:243595
" ... We propose here an oscillatory neural network model that performs the function of an autoencoder. The model is a hybrid of rate-coded neurons and neural oscillators. Input signals modulate the frequency of the neural encoder oscillators. These signals are then multiplexed using a network of rate-code neurons that has afferent Hebbian and lateral anti-Hebbian connectivity, termed as Lateral Anti Hebbian Network (LAHN). Finally the LAHN output is de-multiplexed using an output neural layer which is a combination of adaptive Hopf and Kuramoto oscillators for the signal reconstruction. The Kuramoto-Hopf combination performing demodulation is a novel way of describing a neural phase-locked loop. The proposed model is tested using both synthetic signals and real world EEG signals. The proposed model arises out of the general motivation to construct biologically inspired, oscillatory versions of some of the standard neural network models, and presents itself as an autoencoder network based on oscillatory neurons applicable to time series signals. As a demonstration, the model is applied to compression of EEG signals."
Reference:
1 . Soman K, Muralidharan V, Chakravarthy VS (2018) An Oscillatory Neural Autoencoder Based on Frequency Modulation and Multiplexing. Front Comput Neurosci 12:52 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type: Connectionist Network;
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: MATLAB;
Model Concept(s): Oscillations;
Implementer(s): Soman, Karthik [karthi.soman at gmail.com];
% This code implements an oscillatory autoencoder model that accepts a
% 4-dimensional simulated input, compresses it to a 2-dimensional signal,
% and tries to reconstruct the original signal from it.
clc
clear all
close all
%% Signal generation
Fs=6000; Ts=1/Fs;                   % sampling frequency (Hz) and sampling interval (s)
t0 = 15;                            % signal duration (s)
ts = Ts;                            % sampling interval (same as Ts)
fc1 = 200;                          % carrier frequency of channel 1 (Hz)
fc2 = 350;                          % carrier frequency of channel 2 (Hz)
fc3 = 850;                          % carrier frequency of channel 3 (Hz)
fc4 = 1000;                         % carrier frequency of channel 4 (Hz)

kf = 50;                            % frequency sensitivity: Hz of deviation per unit message amplitude
t = 0:ts:t0-ts;                     % time vector (row)
% msg = [1*ones(1,t0/(3*ts)),-2*ones(1,t0/(3*ts)),zeros(1,t0/(3*ts)+1)];
% Message signals: each channel is the sum of two low-frequency sinusoids.
msg1=sin(2*pi*5*t)+0.5*sin(2*pi*6*t);
msg2=sin(2*pi*10*t)+0.5*sin(2*pi*14*t);
msg3=sin(2*pi*25*t)+0.5*sin(2*pi*28*t);
msg4=sin(2*pi*35*t)+0.5*sin(2*pi*40*t);
% figure;subplot(2,1,1); plot(msg1); title('Message signal1')
% subplot(2,1,2); plot(msg2); title('Message signal2')
%% FM Modulation
% The instantaneous angular frequency of channel j is 2*pi*(fcj + kf*msgj).
% The phase is its cumulative (right-endpoint) sum with step Ts. It starts
% at 0 on the first sample and is evaluated at every sample, including the
% last one, so the modulated signals have exactly length(t) samples.
fm_phase_dot1 = [0, 2*pi*fc1 + 2*pi*kf*msg1(2:end)];
fm_phase_dot2 = [0, 2*pi*fc2 + 2*pi*kf*msg2(2:end)];
fm_phase_dot3 = [0, 2*pi*fc3 + 2*pi*kf*msg3(2:end)];
fm_phase_dot4 = [0, 2*pi*fc4 + 2*pi*kf*msg4(2:end)];
fm_phase1 = cumsum(fm_phase_dot1*Ts);
fm_phase2 = cumsum(fm_phase_dot2*Ts);
fm_phase3 = cumsum(fm_phase_dot3*Ts);
fm_phase4 = cumsum(fm_phase_dot4*Ts);
fm_sig1=sin(fm_phase1);   % modulated signal, channel 1
fm_sig2=sin(fm_phase2);   % modulated signal, channel 2
fm_sig3=sin(fm_phase3);   % modulated signal, channel 3
fm_sig4=sin(fm_phase4);   % modulated signal, channel 4

tstart=6000;              % time-domain plots show the last tstart+1 samples
figure; subplot(2,2,1);plot(t(end-tstart:end),msg1(end-tstart:end));xlim([t(end-tstart) t(end)]); title('message signal1')
subplot(2,2,2);plot(t(end-tstart:end),msg2(end-tstart:end));xlim([t(end-tstart) t(end)]); title('message signal2')
subplot(2,2,3);plot(t(end-tstart:end),msg3(end-tstart:end));xlim([t(end-tstart) t(end)]); title('message signal3')
subplot(2,2,4);plot(t(end-tstart:end),msg4(end-tstart:end));xlim([t(end-tstart) t(end)]); title('message signal4')
%% MUX using k nodes
% Stack the four FM signals as columns: X is (#samples x 4).
X=[fm_sig1' fm_sig2' fm_sig3' fm_sig4'];
PI1d = X';
[N, K] = size(PI1d); % N --> dimension    K --> number of samples

%MUX using PCA
% pca (Statistics and Machine Learning Toolbox) returns the principal
% component coefficients as its first output, one column per component,
% ordered by decreasing component variance. Its second output is the score
% matrix (not the eigenvalues), which is not needed here.
v = pca(X);
wt = v(:,1:2); % extract first two PCs (4 x 2 projection matrix)

%MUX using network of PCA subspace learning (if using this comment above analytical PCA)
%load saved wts (comment the below line if lahn is getting trained using the foldiak_linear_fn)
% load('lahn_wts_saved_2')
% wt = T';

%train lahn (uncomment the below lines to train lahn)
% PI1d=removemean(PI1d);
% alphaa = 0.000001/K;
% betaa = 0.000001/K; 
% output_neuron_nmbr = 2;
% maxiter = 2000000;
% [T,InfoTransferRatio] = foldiak_linear_fn(PI1d, alphaa, betaa, output_neuron_nmbr, maxiter);

%wt(:,1)=wt(:,1)/norm(wt(:,1));wt(:,2)=wt(:,2)/norm(wt(:,2));
% Project the signals onto the two retained directions: MUX is
% (#samples x 2). X is used as-is (not mean-centred) for the projection.
MUX = X*wt;  MUX1=MUX(:,1);  MUX2=MUX(:,2);

%% FFT of the signal
nSamp = length(MUX);                 % number of samples per signal
L = nSamp*Ts;                        % signal duration (s)
t = (0:nSamp-1)*Ts;                  % time vector, one entry per sample
% Frequency axis in fftshift order: shifted bin k (0-based) corresponds to
% (k - floor(nSamp/2))*Fs/nSamp Hz, so the DC bin lands exactly on 0 Hz
% for both even and odd nSamp.
f = ((0:nSamp-1) - floor(nSamp/2))*(Fs/nSamp);

% Spectra are scaled by 2/Fs (1/Fs approximates the continuous-time Fourier
% transform), so a unit-amplitude tone peaks at roughly L (the duration in
% seconds) rather than at 1.
fft1 = abs((2/Fs)*fft(fm_sig1));
fft2 = abs((2/Fs)*fft(fm_sig2));
fft3 = abs((2/Fs)*fft(fm_sig3));
fft4 = abs((2/Fs)*fft(fm_sig4));
fftMUX1 = abs((2/Fs)*fft(MUX1));
fftMUX2 = abs((2/Fs)*fft(MUX2));

figure
subplot(6,1,1); plot(f,fftshift(fft1),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of FM1')
subplot(6,1,2); plot(f,fftshift(fft2),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of FM2')
subplot(6,1,3); plot(f,fftshift(fft3),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of FM3')
subplot(6,1,4); plot(f,fftshift(fft4),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of FM4')
subplot(6,1,5); plot(f,fftshift(fftMUX1),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of Composite signal1');
subplot(6,1,6); plot(f,fftshift(fftMUX2),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of Composite signal2');
xlabel('Frequency'); ylabel('Amplitude')

%% Demux using Adaptive Hopf dynamics
% One adaptive-frequency Hopf oscillator per original channel, integrated
% with forward Euler (step dt). Oscillator j is driven by F, channel j as
% reconstructed from the 2-D MUX signal, and its frequency variable omega
% (rad/s) adapts in response to that drive:
%   rdot     = r*(1 - r^2)
%   phidot   = omega - A*(epsl/r)*F*sin(phi)
%   omegadot = -A*epsl*F*sin(phi)
bf1 = 1*fc1-1*fc1/2;  bf2 = 1*fc2+1*fc2/2; bf3 = 1*fc3-1*fc3/8; bf4 = 1*fc4+1*fc4/4;   % initial frequencies (Hz), offset from the carriers
omega1 = 2*pi*bf1;  omega2 = 2*pi*bf2; omega3 = 2*pi*bf3;  omega4 = 2*pi*bf4;
r1=1;   r2=1;   r3=1;   r4=1;
phi1=rand;  phi2=rand;  phi3=rand;  phi4=rand;     % random initial phases (no RNG seed is set)
epsl1 = 0.9;    epsl2 = 0.9;    epsl3 = 0.9;    epsl4 = 0.9;
dt=Ts;
A1=1000; A2=2000; A3=2000; A4=1500;    
% wt_from_MUX_to_Hopf = [1 0;0 1;0 1;1 0]';
% wt_from_MUX_to_Hopf = wt';
% Reconstruct the 4 channels from MUX with the pseudo-inverse of the
% projection: column j of MUX_to_Hopf is the estimate of channel j.
wt_from_MUX_to_Hopf = pinv(wt);
MUX_to_Hopf = MUX*wt_from_MUX_to_Hopf;             % (#samples x 4)

% Preallocated state of all four oscillators, one row per oscillator.
nHopf = length(MUX);
A_all    = [A1; A2; A3; A4];
epsl_all = [epsl1; epsl2; epsl3; epsl4];
r_all     = zeros(4, nHopf);  r_all(:,1)     = [r1; r2; r3; r4];
phi_all   = zeros(4, nHopf);  phi_all(:,1)   = [phi1; phi2; phi3; phi4];
omega_all = zeros(4, nHopf);  omega_all(:,1) = [omega1; omega2; omega3; omega4];
for ii=2:nHopf
    F  = MUX_to_Hopf(ii-1,:).';                    % drive of each oscillator (4 x 1)
    sp = sin(phi_all(:,ii-1));
    rdot     = r_all(:,ii-1).*(1-r_all(:,ii-1).^2);
    phidot   = omega_all(:,ii-1) - A_all.*(epsl_all./r_all(:,ii-1)).*F.*sp;
    omegadot = -A_all.*epsl_all.*F.*sp;
    r_all(:,ii)     = r_all(:,ii-1)     + rdot*dt;
    phi_all(:,ii)   = phi_all(:,ii-1)   + phidot*dt;
    omega_all(:,ii) = omega_all(:,ii-1) + omegadot*dt;
end
% Per-oscillator row vectors, as read by the plots and later sections.
r1 = r_all(1,:);          r2 = r_all(2,:);          r3 = r_all(3,:);          r4 = r_all(4,:);
phi1 = phi_all(1,:);      phi2 = phi_all(2,:);      phi3 = phi_all(3,:);      phi4 = phi_all(4,:);
omega1 = omega_all(1,:);  omega2 = omega_all(2,:);  omega3 = omega_all(3,:);  omega4 = omega_all(4,:);
x1=r1.*cos(phi1);   x2=r2.*cos(phi2);   x3=r3.*cos(phi3);   x4=r4.*cos(phi4);
y1=r1.*sin(phi1);   y2=r2.*sin(phi2);   y3=r3.*sin(phi3);   y4=r4.*sin(phi4);
figure; plot(t,omega1/(2*pi),'Linewidth',2);  hold on; xlabel('Time'); ylabel('Frequency of Hopf oscillator')
plot(t,omega2/(2*pi),'r','Linewidth',2);  
plot(t,omega3/(2*pi),'g','Linewidth',2);  
plot(t,omega4/(2*pi),'k','Linewidth',2);  
legend('Frequency of HO1','Frequency of HO2','Frequency of HO3','Frequency of HO4')
%% Hopf oscillator outputs and their spectra
% Time-domain view of the last tstart+1 samples of each oscillator's
% y = r*sin(phi) output.
figure; subplot(2,2,1);plot(t(end-tstart:end),y1(end-tstart:end));xlim([t(end-tstart) t(end)]); title('Hopf oscillations1')
subplot(2,2,2); plot(t(end-tstart:end),y2(end-tstart:end));xlim([t(end-tstart) t(end)]); title('Hopf oscillations2')
subplot(2,2,3); plot(t(end-tstart:end),y3(end-tstart:end));xlim([t(end-tstart) t(end)]); title('Hopf oscillations3')
subplot(2,2,4); plot(t(end-tstart:end),y4(end-tstart:end));xlim([t(end-tstart) t(end)]); title('Hopf oscillations4')

% Spectra use the same 2/Fs scaling as fftMUX1/fftMUX2 so the panels of the
% figure below are directly comparable.
fft_y1 = abs((2/Fs)*fft(y1));
fft_y2 = abs((2/Fs)*fft(y2));
fft_y3 = abs((2/Fs)*fft(y3));
fft_y4 = abs((2/Fs)*fft(y4));

figure; subplot(6,1,1);plot(f,fftshift(fftMUX1),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of Composite signal1'); 
subplot(6,1,2);plot(f,fftshift(fftMUX2),'Linewidth',2);xlim([0 max(f)/2]);title('FFT of Composite signal2'); 
subplot(6,1,3);  plot(f,fftshift(fft_y1),'Linewidth',2);xlim([0 max(f)/2]); title('FFT of HO1') 
subplot(6,1,4);  plot(f,fftshift(fft_y2),'Linewidth',2);xlim([0 max(f)/2]); title('FFT of HO2') 
subplot(6,1,5);  plot(f,fftshift(fft_y3),'Linewidth',2);xlim([0 max(f)/2]); title('FFT of HO3') 
subplot(6,1,6);  plot(f,fftshift(fft_y4),'Linewidth',2);xlim([0 max(f)/2]); title('FFT of HO4') 
xlabel('Frequency'); ylabel('Amplitude')
%% Demodulation using PLL
% One phase-locked loop per channel. The VCO phase advances at the
% channel's carrier frequency plus VCO_inp_gain times the previous
% phase-detector output. The phase detector is sin(phi - VCO_phase), where
% phi is the Hopf oscillator phase. The detector output (stored as LPF_out)
% is smoothed afterwards by leaky integrators.
VCO_out_gain = 1;
VCO_inp_gain = Fs;
LPF_Gain = 1;

% Every sample up to and including the last is computed, so no end padding
% is needed. nPll guards against t and phi1 differing in length.
nPll = min(length(t), length(phi1));
fc_pll  = [fc1; fc2; fc3; fc4];
phi_pll = [phi1(1:nPll); phi2(1:nPll); phi3(1:nPll); phi4(1:nPll)];
VCO_phase = zeros(4, nPll);     % VCO phase, starts at 0
LPF_out   = zeros(4, nPll);     % phase-detector output, starts at 0
for ii=2:nPll
    VCO_phase_dot   = 2*pi*fc_pll + VCO_inp_gain*LPF_out(:,ii-1);
    VCO_phase(:,ii) = VCO_phase(:,ii-1) + VCO_phase_dot*Ts;
    LPF_out(:,ii)   = LPF_Gain*sin(phi_pll(:,ii)-VCO_phase(:,ii));
end
VCO_out = VCO_out_gain*cos(VCO_phase);   % VCO waveform, kept for inspection
VCO_out1 = VCO_out(1,:);  VCO_out2 = VCO_out(2,:);  VCO_out3 = VCO_out(3,:);  VCO_out4 = VCO_out(4,:);
% Normalize each detector output by its own maximum value (assumes the
% maximum is positive, which holds once the loop has produced any positive
% detector sample).
LPF_out1 = LPF_out(1,:)/max(LPF_out(1,:));
LPF_out2 = LPF_out(2,:)/max(LPF_out(2,:));
LPF_out3 = LPF_out(3,:)/max(LPF_out(3,:));
LPF_out4 = LPF_out(4,:)/max(LPF_out(4,:));

%% LPF (low pass filter) out to Leaky Integrator
% Two cascaded first-order leaky integrators per channel, integrated with
% forward Euler (step dt) using the current input sample:
%   dLIF1/dt = wc*(-LIF1 + LPF_out)
%   dLIF2/dt = wc*(-LIF2 + LIF1)
% wc (rad/s) corresponds to 6, 10, 30 and 40 Hz for channels 1..4.
niter=length(LPF_out1);
wc1=2*pi*6; wc2=2*pi*10;    wc3=2*pi*30; wc4=2*pi*40;
dt=Ts;
wc_all = [wc1; wc2; wc3; wc4];
lif_in = [LPF_out1; LPF_out2; LPF_out3; LPF_out4];   % assumes equal-length row vectors
LIF1_all = zeros(4, niter);     % first stage, starts at 0
LIF2_all = zeros(4, niter);     % second stage, starts at 0
for ii=2:niter
    LIF1_all(:,ii) = LIF1_all(:,ii-1) + wc_all.*(-LIF1_all(:,ii-1)+lif_in(:,ii))*dt;
    LIF2_all(:,ii) = LIF2_all(:,ii-1) + wc_all.*(-LIF2_all(:,ii-1)+LIF1_all(:,ii))*dt;
end
% Per-channel row vectors of both stages.
LIF1_HO1 = LIF1_all(1,:);  LIF1_HO2 = LIF1_all(2,:);  LIF1_HO3 = LIF1_all(3,:);  LIF1_HO4 = LIF1_all(4,:);
LIF2_HO1 = LIF2_all(1,:);  LIF2_HO2 = LIF2_all(2,:);  LIF2_HO3 = LIF2_all(3,:);  LIF2_HO4 = LIF2_all(4,:);

% For each channel, open a figure comparing the original message (top)
% with the second-stage leaky-integrator output (bottom) over the last
% tstart+1 samples. The y-axis of the bottom panel is padded around the
% visible trace by a per-channel margin.
msgSignals   = {msg1, msg2, msg3, msg4};
demodSignals = {LIF2_HO1, LIF2_HO2, LIF2_HO3, LIF2_HO4};
msgTitles    = {'Message signal1', 'Message signal2', 'Message signal3', 'Message signal4'};
demodTitles  = {'Demodulated signal1', 'Demodulated signal2', 'Demodulated signal3', 'Demodulated signal4'};
yMargins     = [0.03 0.02 0.02 0.02];
tWin = t(end-tstart:end);
tLim = [t(end-tstart) t(end)];
for ch = 1:4
    figure;
    subplot(2,1,1);
    plot(tWin, msgSignals{ch}(end-tstart:end), 'Linewidth', 2);
    xlim(tLim);
    title(msgTitles{ch});
    visible = demodSignals{ch}(end-tstart:end);
    subplot(2,1,2);
    plot(tWin, visible, 'Linewidth', 2);
    xlim(tLim);
    ylim([min(visible)-yMargins(ch) max(visible)+yMargins(ch)]);
    title(demodTitles{ch});
end