function [rate_func] = HMM_v4(x)
% [rate_func] = HMM_v4(x)
%
% Function `HMM_v4' returns the firing rate selected as an alternative hidden state.
% Original paper:
% Mochizuki and Shinomoto, Analog and digital codes in the brain
% https://arxiv.org/abs/1311.4035
%
% Example usage:
% rate_func = HMM_v4(x);
%
% Input argument
% x: Sample data vector (it is sorted internally, so it need not be ordered).
%    It must contain at least two distinct values; otherwise the bin width
%    derived from it would be zero.
%
% Output argument
% rate_func: rate function.
% 2D array stores
% column 1: beginning time of each bin in seconds
% column 2: rate of each bin
%
% Side effect: the estimated rate function is plotted via drawHMM.
%
% made by Yasuhiro Mochizuki
% revised by Kazuki Nakamura
% HMM_v4.m by Kazuki Nakamura 2019/02/25
% Contact: Shigeru Shinomoto: shinomoto@scphys.kyoto-u.ac.jp
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% trailing semicolon added: without it the whole sorted vector is echoed
x = sort(x);
% an empty or zero-range sample gives optw == 0 below; fail early and clearly
if numel(x) < 2 || x(end) <= x(1)
    error('HMM_v4:invalidInput', ...
        'x must contain at least two distinct sample values.');
end
% determine the times of initiation and termination
% (the observed range padded by 0.1% on each side)
onset = x(1) - 0.001 * (x(end) - x(1));
offset = x(end) + 0.001 * (x(end) - x(1));
% bin size = 5*(mean inter-spike interval)
optw = (offset - onset) / length(x) * 5;
% input: sample data vector and bin size
% compute the rate in each bin
rate_func = get_hmm_ratefunc(x, optw);
% draw a graph
drawHMM(rate_func)
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% sub functions
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Function acquiring the observation sequence from a spike train
%
% arguments:
% vec_spkt: spike time measured from the initial spike
% bin_width: bin size
% returns:
% vec_Xi: observation values consisting of spike counts in each bin.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function vec_Xi = get_vec_Xi(vec_spkt, bin_width)
% Builds the spike-count observation sequence: the number of bins is
% ceil(last spike time / bin_width) and vec_Xi starts as a column of zeros.
%
% NOTE(review): this block appears to be corrupted. From the line
% `if bin_id1` onwards the code refers to names that are not defined in this
% function (bin_id1, num_of_obs, num_of_states, mat_alpha, mat_A,
% mat_emission, vec_C), and the `end` keywords below do not balance the
% constructs opened above. It looks as if the text between the bin bounds
% check of get_vec_Xi and the n>=2 part of a forward (alpha) recursion was
% lost - presumably everything between a '<' and a later '>' was stripped.
% TODO: restore the missing text from the upstream file; do not rely on this
% block as written.
bin_num=ceil(vec_spkt(length(vec_spkt))/bin_width);
vec_Xi = zeros(bin_num, 1);
% counting spikes
for i=1:length(vec_spkt)
% 1-based index of the bin containing spike i
bin_id=fix(vec_spkt(i)/bin_width)+1;
% NOTE(review): `bin_id1` is undefined; the original condition and the
% statement that increments vec_Xi are missing here.
if bin_id1
% --- what follows reads like the tail of a scaled forward recursion ---
for n=2:num_of_obs
for i=1:num_of_states
% removed:
% sum_j=0.0;
% for j=1:num_of_states
% sum_j = sum_j + mat_alpha(n-1,j)*mat_A(j,i);
% end
% vectorised form of the removed loop: sum over j of alpha(n-1,j)*A(j,i)
% NOTE(review): no trailing semicolon, so sum_j is echoed on every iteration.
sum_j = sum(mat_alpha(n-1,:)'.*mat_A(:,i))
mat_alpha(n,i) = mat_emission(n,i)*sum_j;
end
% scale row n so that it sums to one; the scale factor is kept in vec_C(n)
vec_C(n)=sum(mat_alpha(n,:));
mat_alpha(n,:) = mat_alpha(n,:)./vec_C(n);
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Function to get beta
%
% arguments:
% mat_A: transition matrix
% vec_pi: initial probability
% mat_emission: matrix consisting of the probability of having the observation (step,state)
% vec_C: scaling coefficient obtained when computing alpha
% returns:
% mat_beta: backward parameter
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function mat_beta = get_beta(mat_A, vec_pi, mat_emission, vec_C)
% Visible part of get_beta: allocates mat_beta as (observations x states)
% and sets its last row to 1.
%
% NOTE(review): this block appears to be corrupted. After the `% n1` comment
% the code no longer touches mat_beta; instead it uses mat_hs_seq,
% vec_logp_seq and vec_h_logprob_i, none of which are initialised here, and
% it ends by assigning vec_hs_seq, which is not this function's output.
% It looks as if the backward recursion of get_beta and the header and
% initialisation of a Viterbi-style path search were lost - presumably
% everything between a '<' and a later '>' was stripped. get_hmm_ratefunc
% calls HMM_Viterbi, whose definition is not visible; the code below may be
% its tail - confirm against the upstream file.
% TODO: restore the missing text; do not rely on this block as written.
num_of_states=length(vec_pi);
num_of_obs=length(mat_emission(:, 1));
% initialize
mat_beta = zeros(num_of_obs, num_of_states);
% n=N
mat_beta(num_of_obs,:)=1.0;
% n1
% --- what follows reads like the main loop of a Viterbi path search ---
for n=2:num_of_obs
% copy the seq. up to n-1
mat_hs_seq_buf=mat_hs_seq;
vec_logp_seq_buf=vec_logp_seq;
% nth node->j
for j=1:num_of_states
% n-1th node->j
% compute logp for i->j trans
% (log(...)/log(10) is a base-10 logarithm)
for i=1:num_of_states
vec_h_logprob_i(i)=vec_logp_seq(i)+log(mat_emission(n,j)*mat_A(i,j))/log(10);
end
% get max logp
[max_element,max_pos]=max(vec_h_logprob_i);
vec_logp_seq_buf(j)=max_element;
% the best path ending in state j extends the best predecessor's path
mat_hs_seq_buf(j,:)=mat_hs_seq(max_pos,:);
mat_hs_seq_buf(j,n)=j;
end
% update the seq.
mat_hs_seq=mat_hs_seq_buf;
vec_logp_seq=vec_logp_seq_buf;
end
% pick the stored path with the largest accumulated log-probability
[max_element, max_pos]=max(vec_logp_seq);
vec_hs_seq=mat_hs_seq(max_pos,:);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% get_hmm_ratefunc: compute the transition between the hidden states
%
% infers optimal states using the Baum-Welch algorithm and the Viterbi algorithm
% arguments:
% spike_times
% bin_width: bin size
% returns:
% hidden states (rate_func). Here rate_func is given by a matrix of (time, state)
% 2D array stores
% 1: beginning time of each bin in seconds
% 2: rate of each bin
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function rate_func= get_hmm_ratefunc(spike_time, bin_width)
% Estimate a two-state rate function from a spike train.
%
% The model parameters (transition matrix mat_A, initial probabilities
% vec_pi, per-bin expected counts vec_lambda) are iterated with HMM_E_step /
% HMM_M_step until they stop changing, then HMM_Viterbi selects one state
% per bin. Those helpers and get_vec_Xi are defined elsewhere in this file.
%
% arguments:
% spike_time: spike times; the first and last elements are used as the
%             start and end of the recording, so the vector must be sorted
% bin_width: bin size
% returns:
% rate_func: (number of bins) x 2 matrix
%   column 1: beginning time of each bin, rounded to 0.01
%   column 2: rate of the state selected for that bin

% set the initial values of the model parameters
EMloop_num=5000; % number of EM iteration
mat_A=[0.999 0.001; 0.001 0.999];
vec_pi=[0.5 0.5];
mean_rate=length(spike_time)/(spike_time(end)-spike_time(1));
% the two states start at 75% and 125% of the mean spike count per bin
vec_lambda=[(mean_rate*0.75)*bin_width (mean_rate*1.25)*bin_width];
% vec_spkt: spike times measured from the initial spike (column vector)
vec_spkt=spike_time(:)-spike_time(1);
% vec_Xi consists of the number of spikes (0, 1, 2, 3, ..) in each step.
vec_Xi=get_vec_Xi(vec_spkt, bin_width);
% Optimizing the model parameters using the Baum-Welch algorithm
% updates parameters by HMM_E_step and HMM_M_step
[mat_Gamma, mat_Xi]=HMM_E_step(vec_Xi, mat_A, vec_lambda, vec_pi);
mat_A_old=mat_A;
vec_pi_old=vec_pi;
vec_lambda_old=vec_lambda;
% The loop stops when the mean absolute change of the model parameters
% drops below 1e-7, or after EMloop_num+1 iterations.
loop=0;
flag=0;
while (loop<=EMloop_num && flag==0)
    [vec_pi_new, vec_lambda_new, mat_A_new]=HMM_M_step(vec_Xi, mat_A, vec_lambda, vec_pi, mat_Gamma, mat_Xi);
    vec_pi=vec_pi_new;
    vec_lambda=vec_lambda_new;
    mat_A=mat_A_new;
    num_state=length(vec_pi);
    % Total absolute change over every element of A, pi and lambda.
    % Fixes: (1) `.-` is not a MATLAB operator, plain `-` is used;
    % (2) the previous code indexed vec_pi/vec_lambda with a stale loop
    % variable `i` instead of comparing the whole vectors.
    % (:) makes the comparison independent of row/column orientation.
    sum_check=sum(abs(mat_A_old(:)-mat_A(:))) ...
        +sum(abs(vec_pi_old(:)-vec_pi(:))) ...
        +sum(abs(vec_lambda_old(:)-vec_lambda(:)));
    % num_state*(num_state+2) is the number of compared parameters
    if sum_check/(1.0*num_state*(num_state+2))<10^(-7)
        flag=flag+1;
    end
    mat_A_old=mat_A;
    vec_pi_old=vec_pi;
    vec_lambda_old=vec_lambda;
    [mat_Gamma, mat_Xi]=HMM_E_step(vec_Xi, mat_A, vec_lambda, vec_pi);
    loop=loop+1;
end
% Estimate an optimal sequence of states using the Viterbi algorithm
vec_hidden=HMM_Viterbi(vec_Xi, mat_A, vec_lambda, vec_pi);
% vec_hidden(n) is used directly as an index into vec_lambda, so it has to
% hold 1-based state indices.
% The selected state of each bin is transformed into a rate.
rate_func=zeros(length(vec_Xi),2);
onset=spike_time(1)-0.001*(spike_time(end)-spike_time(1));
c_time=onset;
for n=1:length(vec_Xi)
    state_id=vec_hidden(n);
    rate_func(n,1)=round(c_time*100)/100.0;
    % expected count per bin (rounded to 0.01) divided by the bin width
    rate_func(n,2)=round(vec_lambda(state_id)*100)/(bin_width*100.0);
    c_time=c_time+bin_width;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% draw a figure of the estimated state (given in a form of firing rate)
% arguments:
% rate_func = (time, rate) determined by the Baum-Welch algorithm and the Viterbi algorithm, given as a two-column matrix
% (rate_func is the only input; drawHMM takes no spike_times argument)
% returns:
% nothing, but draws a figure
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function drawHMM(rate_func)
% Plot the piecewise-constant rate stored in rate_func as a step curve.
%
% arguments:
% rate_func: matrix whose column 1 holds the bin start times and whose
%            column 2 holds the rate of each bin
% returns:
% nothing, but draws a figure with plot()
x = rate_func(:, 1);
y = rate_func(:, 2);
% indices k for which the rate changes between bin k and bin k+1
k = find(y(1:end-1) ~= y(2:end));
k = k(:);
% every change is drawn as a vertical edge halfway between the two bin
% start times: two vertices share the same abscissa
t_switch = reshape((x(k) + x(k + 1)) / 2, 1, []);
step_x = [t_switch; t_switch];
step_y = [reshape(y(k), 1, []); reshape(y(k + 1), 1, [])];
% the curve starts with a vertical segment from min(y) to the first rate
% and ends at the start time of the last bin; building the vertex lists in
% one step avoids growing the arrays inside a loop
x_new = [x(1), x(1), step_x(:).', x(end)];
y_new = [min(y), y(1), step_y(:).', y(end)];
plot(x_new, y_new);
% axis() raises an error for non-increasing limits, which happens with a
% single bin (min(x) == max(x)) or when every rate is zero; in that case the
% automatic limits chosen by plot() are kept
x_lim = [min(x) max(x)];
y_top = max(y) * 1.1;
if x_lim(2) > x_lim(1) && y_top > 0
    axis([x_lim(1) x_lim(2) 0 y_top]);
end
end