% function [B_matrix, pi_vector, total_obs_distribution] = trainBandPiMatricesSpringer(state_observation_values)
%
% Train the B matrix and pi vector for the Springer HMM.
% The pi vector holds the initial state probabilities, while the B matrix
% holds the observation probabilities. In the case of Springer's algorithm,
% the observation probabilities are derived from logistic regression
% models, one trained per state.
%
%% Inputs:
% state_observation_values: an Nx4 cell array of observation values from
% each of N PCG signals for each (of 4) state. Within each cell is a KxJ
% double array, where K is the number of samples from that state in the PCG
% and J is the number of feature vectors extracted from the PCG.
%
%% Outputs:
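% The B_matrix and pi arrays for an HMM - as Springer et al's algorithm is a
% duration-dependent HMM, there is no need to calculate the A_matrix, as
% the transition between states is only dependent on the state durations.
% total_obs_distribution: a 2x1 cell array holding the mean (cell 1) and the
% covariance (cell 2) of the observations pooled over all states and PCGs.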
%
% Developed by David Springer for the paper:
% D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
% Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.
%
%% Copyright (C) 2016 David Springer
% dave.springer@gmail.com
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program. If not, see <http://www.gnu.org/licenses/>.
function [B_matrix, pi_vector, total_obs_distribution] = trainBandPiMatricesSpringer(state_observation_values)
% Train the per-state logistic regression models (B_matrix), the initial
% state probabilities (pi_vector) and the pooled observation distribution.
%
% Inputs:
%   state_observation_values - N x 4 cell array; cell {n,s} is a K x J
%       double array of the observations of recording n while in state s
%       (rows are samples, columns are features).
%
% Outputs:
%   B_matrix               - 1 x 4 cell array; cell s holds the coefficient
%                            vector returned by mnrfit for the
%                            "state s versus the rest" model.
%   pi_vector              - 1 x 4 vector of initial state probabilities.
%   total_obs_distribution - 2 x 1 cell array: {1} is the mean and {2} the
%                            covariance of all pooled observations.

%% Prelim
number_of_states = 4;

%% Set pi_vector
% The true value of the pi vector, which are the initial state
% probabilities, are dependent on the heart rate of each PCG, and the
% individual sound duration for each patient. Therefore, instead of setting
% a patient-dependent pi_vector, simplify by setting all states as equally
% probable:
pi_vector = [0.25,0.25,0.25,0.25];

%% Train the logistic regression-based B_matrix:
% Initialise the B_matrix as a 1x4 cell array. This is to hold the
% coefficients of the trained logistic regression model for each state.
B_matrix = cell(1,number_of_states);

% Pool the observations of every recording, state by state. Expanding the
% whole cell column iterates over the rows (recordings) of the input
% regardless of how many there are; length() must not be used for this, as
% it returns the largest dimension and therefore 4 for any input with fewer
% than 4 recordings.
statei_values = cell(number_of_states,1);
for state = 1:number_of_states
    statei_values{state} = vertcat(state_observation_values{:,state});
end

% Number of samples (rows) available in each state. size(...,1) is used
% rather than length(), which would return the number of features whenever
% a state has fewer samples than features.
samples_per_state = cellfun(@(values) size(values,1), statei_values);

% In order to use Bayes' formula with the logistic regression derived
% probabilities, we need to get the probability of seeing a specific
% observation in the total training data set. This is the
% 'total_observation_sequence', and its mean and covariance are found:
total_observation_sequence = vertcat(statei_values{:});
total_obs_distribution = cell(2,1);
total_obs_distribution{1} = mean(total_observation_sequence);
total_obs_distribution{2} = cov(total_observation_sequence);

for state = 1:number_of_states

    % Randomly select indices of samples from the other states not being
    % learnt, in order to balance the two data sets. The code below ensures
    % that if class 1 is being learnt vs the rest, the number of the rest =
    % the number of class 1, evenly split across all other classes.

    % Number of samples required from each of the other states:
    length_per_other_state = floor(samples_per_state(state)/(number_of_states-1));

    % If one of the other states does not hold that many samples, draw
    % only as many as the shortest other state provides (and
    % (number_of_states-1) times that amount from the main state).
    is_other_state = (1:number_of_states) ~= state;
    min_length_other_class = min(samples_per_state(is_other_state));
    length_per_other_state = min(length_per_other_state, min_length_other_class);

    % training_data{1}: samples of the state being learnt.
    % training_data{2}: samples of all the other states.
    training_data = cell(2,1);

    for other_state = 1:number_of_states
        if other_state == state
            % Take (number_of_states-1) * length_per_other_state samples
            % of the main state, so that the two sets are balanced:
            indices = randperm(samples_per_state(other_state), length_per_other_state*(number_of_states-1));
            training_data{1} = statei_values{other_state}(indices,:);
        else
            indices = randperm(samples_per_state(other_state), length_per_other_state);
            training_data{2} = vertcat(training_data{2}, statei_values{other_state}(indices,:));
        end
    end

    % Label all the data: 2 for the state being learnt, 1 for the rest.
    number_in_state = size(training_data{1},1);
    number_in_rest = size(training_data{2},1);
    labels = ones(number_in_state + number_in_rest,1);
    labels(1:number_in_state) = 2;

    % Train the logistic regression model for this state:
    all_data = [training_data{1};training_data{2}];
    B_matrix{state} = mnrfit(all_data,labels);
end