% =======================================
% Train disBM with manifold regularizer
% on both groups of hidden units.
%
% written by Scott Reed
%
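% inputs:
% X: labeled data, one example per column
% U: unlabeled data, one example per column
% Ya: labels for group A (assumed always present)
% Yb: labels for group B (-1 marks a missing label)
% params: model training parameters
% Xval, YvalA, YvalB: presumably validation data/labels (unused here)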
% =======================================

function [weights, grad, params] = disbm_train_man_man(X, Ya, Yb, U, params, Xval, YvalA, YvalB)
% Train a disBM with a manifold regularizer on both hidden-unit groups (A, B).
%
% Each epoch draws params.nsample 2-way correspondence sets (built from Yb
% and Ya by makepairs_2way) and params.nsample unlabeled columns of U, then
% runs minibatch updates of size params.batchsize. Per minibatch:
%   - if params.kman_a > 0 or params.kman_b > 0, manifold_grad gradients are
%     subtracted (scaled by 1/batchsize);
%   - if params.plambda_a > 0 or params.plambda_b > 0, unsup_grad gradients
%     on the unlabeled minibatch are subtracted (scaled by 1/batchsize);
%   - weights are updated with momentum (0.5 for the first
%     init_final_momen_iter epochs, then 0.9) and L2 decay (params.l2reg) on
%     the non-bias weights; factor weights use rate epsilon_mult, the rest
%     use epsilon_add;
%   - visfac, hidfac_a/b and vishid_a/b are rescaled toward running-average
%     norms.
% Nothing in this function fills dvisbias, so visbias gets no gradient here.
% Parameters are written via save_params every 50 epochs.
%
% Outputs:
%   weights: struct with visfac, hidfac_a, hidfac_b, vishid_a, vishid_b,
%            visbias, hidbias_a, hidbias_b
%   grad:    momentum (velocity) terms, same fields as weights
%   params:  the input params with fname_save added
%
% Xval, YvalA, YvalB are accepted for interface compatibility but are not
% used by this function.

% -- initialize parameters
weights.visfac = 0.05*randn(params.numfac,params.numvis);      % 3-way
weights.hidfac_a = 0.05*randn(params.numfac,params.numhid_a);  % 3-way
weights.hidfac_b = 0.05*randn(params.numfac,params.numhid_b);  % 3-way
weights.vishid_a = 0.05*randn(params.numhid_a,params.numvis);  % 2-way
weights.vishid_b = 0.05*randn(params.numhid_b,params.numvis);  % 2-way
weights.visbias = zeros(params.numvis, 1);                     % 1-way (bias)
weights.hidbias_a = zeros(params.numhid_a, 1);                 % 1-way (bias)
weights.hidbias_b = zeros(params.numhid_b, 1);                 % 1-way (bias)

if params.optgpu,
    weights = cpu2gpu_struct(weights);
    to_single = @gsingle;   % data batches live on the GPU
else
    to_single = @single;
end
grad = replicate_struct(weights, 0); % momentum (velocity) accumulators

% -- filename to save
params.fname_save = sprintf('%s_fold_%d', params.fname, params.fold);
fname_mat = sprintf('%s/%s.mat', params.savedir, params.fname_save);
disp(params);

% Optimization parameters.
r1 = params.epsilon_add;    % rate for biases and vis-hid weights
r2 = params.epsilon_mult;   % rate for factor (3-way) weights

nsample = params.nsample;
batchsize = params.batchsize;

init_final_momen_iter = 5;
final_momen = 0.9;
init_momen = 0.5;

% Running-average target norms used by the filter normalization below.
small = 0.0000001;
normvis = 0.1;
normhid_a = 0.1;
normhid_b = 0.1;
normvh_a = 0.1;
normvh_b = 0.1;

% Assume that some Yb labels are missing (assigned to -1).
% In that case we can still use them, but just make a separate set for that,
% and only use correspondence w.r.t. Ya labels, which are assumed to be
% always present.
keepidx = Yb~=-1;
X2 = X(:,~keepidx);
Ya2 = Ya(~keepidx);
Yb2 = Yb(~keepidx);
X = X(:,keepidx);
Ya = Ya(keepidx);
Yb = Yb(keepidx);

corr = makepairs_2way(Yb, Ya);
numsets = size(corr,1);
% If no Yb label is missing there are no 1-way sets to build.
if isempty(Ya2)
    corr2 = zeros(0, 3);
else
    corr2 = makepairs_1way(Ya2);
end
numsets2 = size(corr2,1);
% The 1-way sets only feed the (currently disabled) 1-way manifold term.
% randsample samples without replacement and errors if numsets2 < nsample,
% so only draw them when there are enough; otherwise train on 2-way sets.
use_1way = numsets2 >= nsample;

fprintf('\rTraining disBM %d-%d-%d/%d   epochs:%d r1:%f r2:%f',...
        params.numvis, params.numfac, params.numhid_a, params.numhid_b, ...
        params.maxiter, r1, r2);

for epoch = 1:params.maxiter
    fprintf('\repoch %d\n',epoch);

    % Momentum depends only on the epoch.
    if epoch > init_final_momen_iter,
        momentum = final_momen;
    else
        momentum = init_momen;
    end

    unlab_idx = randsample(size(U,2),nsample);
    epoch_idx = randsample(numsets,nsample);
    if use_1way
        epoch_idx2 = randsample(numsets2,nsample);
    end

    Xugpu = to_single(U(:, unlab_idx));
    X1gpu = to_single(X(:, corr(epoch_idx,1)));
    X2gpu = to_single(X(:, corr(epoch_idx,2)));
    X3gpu = to_single(X(:, corr(epoch_idx,3)));
    if use_1way
        X4gpu = to_single(X2(:, corr2(epoch_idx2,1)));
        X5gpu = to_single(X2(:, corr2(epoch_idx2,2)));
        X6gpu = to_single(X2(:, corr2(epoch_idx2,3)));
    end

    numbatch = floor(size(X1gpu,2)/batchsize);
    for batch = 1:numbatch,
        batch_idx = (1+(batch-1)*batchsize):batchsize*batch;

        Xub = Xugpu(:,batch_idx);

        X1b = X1gpu(:,batch_idx);
        X2b = X2gpu(:,batch_idx);
        X3b = X3gpu(:,batch_idx);
        if use_1way
            X4b = X4gpu(:,batch_idx);
            X5b = X5gpu(:,batch_idx);
            X6b = X6gpu(:,batch_idx);
        end

        dvisfac = 0*weights.visfac;
        dhidfac_a = 0*weights.hidfac_a;
        dhidfac_b = 0*weights.hidfac_b;
        dvishid_a = 0*weights.vishid_a;
        dvishid_b = 0*weights.vishid_b;
        dvisbias = 0*weights.visbias;
        dhidbias_a = 0*weights.hidbias_a;
        dhidbias_b = 0*weights.hidbias_b;

        % Manifold.
        cost_man = 0;
        if (params.kman_a) > 0 || (params.kman_b > 0)
            % 2-way correspondence learning.
            [ grad_manifold, cost_man ] = manifold_grad(weights,X1b,X3b,X2b,params);
            dvisfac = dvisfac - grad_manifold.visfac./batchsize;
            dvishid_a = dvishid_a - grad_manifold.vishid_a./batchsize;
            dvishid_b = dvishid_b - grad_manifold.vishid_b./batchsize;
            dhidfac_a = dhidfac_a - grad_manifold.hidfac_a./batchsize;
            dhidfac_b = dhidfac_b - grad_manifold.hidfac_b./batchsize;
            dhidbias_a = dhidbias_a - grad_manifold.hidbias_a./batchsize;
            dhidbias_b = dhidbias_b - grad_manifold.hidbias_b./batchsize;

            % 1-way correspondence learning (just using emotion).
            % NOTE: if re-enabled, guard with `if use_1way` -- X4b..X6b
            % only exist when enough 1-way sets are available.
            %{
            ptmp = params;
            ptmp.kman_b = 0;
            lambda_1way = size(X2,2)/(size(X2,2)+size(X,2));
            [ grad_manifold, cost_man2 ] = manifold_grad(weights,X4b,X5b,X6b,ptmp);
            dvisfac = dvisfac - lambda_1way*grad_manifold.visfac./batchsize;
            dvishid_a = dvishid_a - lambda_1way*grad_manifold.vishid_a./batchsize;
            dvishid_b = dvishid_b - lambda_1way*grad_manifold.vishid_b./batchsize;
            dhidfac_a = dhidfac_a - lambda_1way*grad_manifold.hidfac_a./batchsize;
            dhidfac_b = dhidfac_b - lambda_1way*grad_manifold.hidfac_b./batchsize;
            dhidbias_a = dhidbias_a - lambda_1way*grad_manifold.hidbias_a./batchsize;
            dhidbias_b = dhidbias_b - lambda_1way*grad_manifold.hidbias_b./batchsize;
            %}
        end

        % Sparsity and any other unsupervised terms.
        cost_sp = 0;
        sp_a = 0;
        sp_b = 0;
        if (params.plambda_a > 0) || (params.plambda_b > 0)
            [ grad_unsup, cost_unsup ] = unsup_grad(weights,Xub,params);
            dvisfac = dvisfac - grad_unsup.visfac./batchsize;
            dvishid_a = dvishid_a - grad_unsup.vishid_a./batchsize;
            dvishid_b = dvishid_b - grad_unsup.vishid_b./batchsize;
            dhidfac_a = dhidfac_a - grad_unsup.hidfac_a./batchsize;
            dhidfac_b = dhidfac_b - grad_unsup.hidfac_b./batchsize;
            dhidbias_a = dhidbias_a - grad_unsup.hidbias_a./batchsize;
            dhidbias_b = dhidbias_b - grad_unsup.hidbias_b./batchsize;
            cost_sp = cost_unsup.sp;
            sp_a = cost_unsup.sp_a;
            sp_b = cost_unsup.sp_b;
        end

        % Update weights (momentum step; L2 decay on non-bias weights only).
        grad.visfac = momentum*grad.visfac + r2*dvisfac - r2*params.l2reg*weights.visfac;
        weights.visfac = weights.visfac + grad.visfac;

        grad.visbias = momentum*grad.visbias + r1*dvisbias;
        weights.visbias = weights.visbias + grad.visbias;

        grad.hidfac_a = momentum*grad.hidfac_a + r2*dhidfac_a - r2*params.l2reg*weights.hidfac_a;
        weights.hidfac_a = weights.hidfac_a + grad.hidfac_a;
        grad.hidfac_b = momentum*grad.hidfac_b + r2*dhidfac_b - r2*params.l2reg*weights.hidfac_b;
        weights.hidfac_b = weights.hidfac_b + grad.hidfac_b;

        grad.hidbias_a = momentum*grad.hidbias_a + r1*dhidbias_a;
        weights.hidbias_a = weights.hidbias_a + grad.hidbias_a;
        grad.hidbias_b = momentum*grad.hidbias_b + r1*dhidbias_b;
        weights.hidbias_b = weights.hidbias_b + grad.hidbias_b;

        grad.vishid_a = momentum*grad.vishid_a + r1*dvishid_a - r1*params.l2reg*weights.vishid_a;
        weights.vishid_a = weights.vishid_a + grad.vishid_a;
        grad.vishid_b = momentum*grad.vishid_b + r1*dvishid_b - r1*params.l2reg*weights.vishid_b;
        weights.vishid_b = weights.vishid_b + grad.vishid_b;

        %%%%%%%%% NORMALIZE FILTERS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % Each filter is rescaled to a running (EMA) average of the filter
        % norms. visfac is normalized per column (sum over dim 1); the
        % other matrices per row (sum over dim 2).
        %n_visfac = sqrt(sum(weights.visfac.^2,2)+small);
        n_visfac = sqrt(sum(weights.visfac.^2,1)+small);
        normvis_inc = mean(n_visfac);
        normvis = 0.95*normvis + 0.05*normvis_inc;
        weights.visfac = bsxfun(@times, weights.visfac, (normvis./n_visfac));

        n_hidfac = sqrt(sum(weights.hidfac_a.^2,2)+small);
        normhid_inc = mean(n_hidfac);
        normhid_a = 0.95*normhid_a + 0.05*normhid_inc;
        weights.hidfac_a = bsxfun(@times, weights.hidfac_a, (normhid_a./n_hidfac));

        n_hidfac = sqrt(sum(weights.hidfac_b.^2,2)+small);
        normhid_inc = mean(n_hidfac);
        normhid_b = 0.95*normhid_b + 0.05*normhid_inc;
        weights.hidfac_b = bsxfun(@times, weights.hidfac_b, (normhid_b./n_hidfac));

        n_vishid_a = sqrt(sum(weights.vishid_a.^2,2)+small);
        normvh_inc = mean(n_vishid_a);
        normvh_a = 0.95*normvh_a + 0.05*normvh_inc;
        weights.vishid_a = bsxfun(@times, weights.vishid_a, (normvh_a./n_vishid_a));

        n_vishid_b = sqrt(sum(weights.vishid_b.^2,2)+small);
        normvh_inc = mean(n_vishid_b);
        normvh_b = 0.95*normvh_b + 0.05*normvh_inc;
        weights.vishid_b = bsxfun(@times, weights.vishid_b, (normvh_b./n_vishid_b));

        %%%%%%%%%%%%%%%% END OF UPDATES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        if mod(batch,5)==0,
            fprintf(1,'batch %d of %d cost_man=%.4f,cost_sp=%.4f,sp_a=%.4f,sp_b=%.4f\n', ...
                    batch, numbatch, cost_man, cost_sp, sp_a, sp_b);
        end
    end
    if mod(epoch,50)==0,
        [weights, grad] = save_params(fname_mat, weights, grad, params, epoch);
    end
end

return;

