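% ==============================================
% Train a disBM on TFD with emotion labels and
% identity clamping
%
% Objective function:
% log P(v1, v2, y1, y2) + eta*log P(y1, y2|v1, v2)
% where (v1, v2) is a pair of images of the same identity
% and (y1, y2) are their emotion labels
%
% written by Kihyuk Sohn
%
% inputs
% xtrain    : labeled training data, one sample per column
% ytrain_a  : emotion labels (one per column of xtrain)
% ytrain_b  : identity labels (one per column of xtrain),
%             used to form same-identity pairs
% params    : hyperparameters
% xval      : (optional) validation data
% yval_a    : (optional) validation emotion labels
%
% outputs
% weights   : learned model parameters
% agrad     : accumulated gradient state
% params    : input params with params.fname_save added
% ==============================================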


function [ weights, agrad, params ] = disbm_train_real_label_clamp(xtrain, ytrain_a, ytrain_b, params, xval, yval_a)
% DISBM_TRAIN_REAL_LABEL_CLAMP  Train a disBM with emotion labels (y) and
% identity clamping: both images of a same-identity pair share one set of
% "b" hidden units, while each image has its own "a" hidden units that are
% connected to its emotion label.
%
% Each minibatch update combines three statistics:
%   pos       : clamped mean-field statistics E[. | v1, v2, y1, y2]
%   grad_disc : eta * (clamped stats - sampled stats of P(ha, hb, y | v1, v2))
%   neg       : model statistics from persistent Gibbs chains (SAP)
%
% Inputs
%   xtrain   : training data, one sample per column (numvis rows, since it
%              is multiplied as weights.visfac'*xtrain(:, idx))
%   ytrain_a : emotion label per column of xtrain; assumed to be integers in
%              1..params.numlab (fed to multi_output, compared with argmax)
%   ytrain_b : identity label per column of xtrain; passed to makepairs to
%              build the list of same-identity index pairs
%   params   : hyperparameter struct (fields read below, e.g. numvis,
%              numfac, numhid_a, numhid_b, numlab, batchsize, negchain,
%              maxiter, kcd, nmf, eps, eps_decay, eta, momentum_*, l2reg,
%              l2reg_f, sp_damp, plambda_a/b, pbias_a/b, usepcd, optgpu,
%              verbose, saveiter, savedir, fname, fold)
%   xval     : (optional) validation data
%   yval_a   : (optional) validation emotion labels; validation is run only
%              when both are given and xval is non-empty
%
% Outputs
%   weights  : learned parameters, as returned by save_params
%   agrad    : accumulated gradient state, as returned by save_params
%   params   : input params with params.fname_save set

% generate identity pairs
ypairs = makepairs(ytrain_b);
clear ytrain_b;

% initialize parameters
weights.visfac = 0.01*randn(params.numvis, params.numfac);          % 3-way
weights.hidfac_a = 0.01*randn(params.numhid_a, params.numfac);      % 3-way
weights.hidfac_b = 0.005*randn(params.numhid_b, params.numfac);     % 3-way
weights.vishid_a = 0.01*randn(params.numvis, params.numhid_a);      % 2-way
weights.vishid_b = 0.005*randn(params.numvis, params.numhid_b);     % 2-way
weights.hidlab = 0.01*randn(params.numhid_a, params.numlab);        % 2-way (softmax)
weights.visbias = zeros(params.numvis, 1);                          % 1-way (bias)
weights.hidbias_a = zeros(params.numhid_a, 1);                      % 1-way (bias)
weights.hidbias_b = zeros(params.numhid_b, 1);                      % 1-way (bias)
weights.labbias = zeros(params.numlab, 1);                          % 1-way (bias)

if params.optgpu,
    weights = cpu2gpu_struct(weights);
end
agrad = replicate_struct(weights, 0); % accumulated gradient

% filename to save
params.fname_save = sprintf('%s_iter_%d_fold_%d_date_%s', params.fname, params.maxiter, params.fold, datestr(now, 30));
fname_mat = sprintf('%s/%s.mat', params.savedir, params.fname_save);
disp(params);


% -----------------------------------
% train disBM with SAP
% -----------------------------------

% each minibatch / chain set is split over the two images of a pair
batchsize_lab = round(params.batchsize/2);
negchain = round(params.negchain/2);
maxiter = params.maxiter;

nsample_pairs = size(ypairs, 1);
% NOTE(review): a minibatch consumes batchsize_lab pairs, yet the count
% divides by 2*batchsize_lab, so at most half of the pairs are visited per
% epoch (and never more than 100 minibatches) - kept as in the original.
numbatch = min(floor(nsample_pairs/(2*batchsize_lab)), 100);

% fail early instead of silently "training" with zero updates or crashing
% later on an undefined variable
if numbatch < 1,
    error('disbm_train_real_label_clamp:tooFewPairs', ...
        'only %d identity pairs available; at least %d are needed for one minibatch', ...
        nsample_pairs, 2*batchsize_lab);
end
if params.nmf < 1 || params.kcd < 1,
    error('disbm_train_real_label_clamp:badParams', ...
        'params.nmf and params.kcd must both be >= 1');
end

% set monitoring variables
history.error = zeros(maxiter,1);
history.cls_error = zeros(maxiter, 1);
history.sparsity_a = zeros(maxiter,1);
history.sparsity_b = zeros(maxiter,1);

% initialize negative (persistent) chains from the data mean
negvisstate_1 = repmat(mean(xtrain, 2), [1, negchain]);
negvisstate_2 = repmat(mean(xtrain, 2), [1, negchain]);
neghidprob_a_1 = sigmoid(bsxfun(@plus, weights.vishid_a'*negvisstate_1, weights.hidbias_a));
neghidstate_a_1 = realize(neghidprob_a_1, params.optgpu);
neghidprob_a_2 = sigmoid(bsxfun(@plus, weights.vishid_a'*negvisstate_2, weights.hidbias_a));
neghidstate_a_2 = realize(neghidprob_a_2, params.optgpu);
negwhaf_1 = weights.hidfac_a'*neghidstate_a_1;
negwhaf_2 = weights.hidfac_a'*neghidstate_a_2;
neghidprob_b_shared = sigmoid(bsxfun(@plus, weights.vishid_b'*negvisstate_1 + weights.vishid_b'*negvisstate_2, weights.hidbias_b));
neghidstate_b_shared = realize(neghidprob_b_shared, params.optgpu);
negwhbf_shared = weights.hidfac_b'*neghidstate_b_shared;

% running average for sparsity
runavg_hid_a = zeros(params.numhid_a, 1);
runavg_hid_b = zeros(params.numhid_b, 1);

for t = 1:maxiter,
    % schedules for sgd: momentum switch, number of Gibbs steps, step size
    if t > params.momentum_change,
        momentum = params.momentum_final;
    else
        momentum = params.momentum_init;
    end
    kcd = min(1 + floor(t/10), params.kcd);
    lr = params.eps/(1+params.eps_decay*t); % named lr to avoid shadowing builtin eps
    
    % randomly shuffle data
    randidx_pairs = randperm(nsample_pairs);
    
    % monitoring variables for epoch
    recon_err_epoch = zeros(numbatch, 1);
    cls_err_epoch = zeros(numbatch, 1);
    sparsity_a_epoch = zeros(numbatch, 1);
    sparsity_b_epoch = zeros(numbatch, 1);
    
    tS = tic;
    for b = 1:numbatch,
        % gradient containers (same fields as weights)
        pos = replicate_struct(weights, 0);
        neg = replicate_struct(weights, 0);
        
        
        % ------------------------------------
        % identity-paired data (pos, gen.)
        % ------------------------------------
        
        % load data
        batchidx = randidx_pairs((b-1)*batchsize_lab+1:b*batchsize_lab);
        data_1 = xtrain(:, ypairs(batchidx, 1));
        data_2 = xtrain(:, ypairs(batchidx, 2));
        label_1 = ytrain_a(ypairs(batchidx, 1));
        label_2 = ytrain_a(ypairs(batchidx, 2));
        poslabprob_1 = multi_output(label_1, params.numlab);
        poslabprob_2 = multi_output(label_2, params.numlab);
        if params.optgpu,
            data_1 = gsingle(data_1);
            data_2 = gsingle(data_2);
            poslabprob_1 = gsingle(poslabprob_1);
            poslabprob_2 = gsingle(poslabprob_2);
        end
        
        % hidden biases for this minibatch (also reused by the disc. phase)
        hbiasmat_a = repmat(weights.hidbias_a, [1 batchsize_lab]);
        hbiasmat_b = repmat(weights.hidbias_b, [1 batchsize_lab]);
        
        % inference with clamping: P(ha1, ha2, hb|v1, v2, y1, y2)
        wvha_v_1 = weights.vishid_a'*data_1;
        wvha_v_2 = weights.vishid_a'*data_2;
        wvhb_v_1 = weights.vishid_b'*data_1;
        wvhb_v_2 = weights.vishid_b'*data_2;
        poswvf_1 = weights.visfac'*data_1;
        poswvf_2 = weights.visfac'*data_2;
        whl_l_1 = weights.hidlab*poslabprob_1;
        whl_l_2 = weights.hidlab*poslabprob_2;
        
        % initialize with 2-way connection
        % followed by mean-field update
        poshidprob_b_shared = sigmoid(wvhb_v_1 + wvhb_v_2 + hbiasmat_b);
        poswhbf_shared = weights.hidfac_b'*poshidprob_b_shared;
        
        for i = 1:params.nmf,
            % per-image hidden units (a)
            poshidprob_a_1 = sigmoid(weights.hidfac_a*(poswvf_1.*poswhbf_shared) + wvha_v_1 + whl_l_1 + hbiasmat_a);
            poswhaf_1 = weights.hidfac_a'*poshidprob_a_1;
            
            poshidprob_a_2 = sigmoid(weights.hidfac_a*(poswvf_2.*poswhbf_shared) + wvha_v_2 + whl_l_2 + hbiasmat_a);
            poswhaf_2 = weights.hidfac_a'*poshidprob_a_2;
            
            % hidden units (b) shared by both images of the same identity
            poshidprob_b_shared = sigmoid(weights.hidfac_b*(poswvf_1.*poswhaf_1 + poswvf_2.*poswhaf_2) + wvhb_v_1 + wvhb_v_2 + hbiasmat_b);
            poswhbf_shared = weights.hidfac_b'*poshidprob_b_shared;
        end
        
        % compute gradient
        pos.visfac = pos.visfac + data_1*(poswhaf_1.*poswhbf_shared)' + data_2*(poswhaf_2.*poswhbf_shared)';
        pos.hidfac_a = pos.hidfac_a + poshidprob_a_1*(poswvf_1.*poswhbf_shared)' + poshidprob_a_2*(poswvf_2.*poswhbf_shared)';
        pos.hidfac_b = pos.hidfac_b + poshidprob_b_shared*(poswvf_1.*poswhaf_1 + poswvf_2.*poswhaf_2)';
        pos.vishid_a = pos.vishid_a + data_1*poshidprob_a_1' + data_2*poshidprob_a_2';
        pos.vishid_b = pos.vishid_b + (data_1 + data_2)*poshidprob_b_shared';
        pos.hidlab = pos.hidlab + poshidprob_a_1*poslabprob_1' + poshidprob_a_2*poslabprob_2';
        pos.visbias = pos.visbias + sum(data_1, 2) + sum(data_2, 2);
        pos.hidbias_a = pos.hidbias_a + sum(poshidprob_a_1, 2) + sum(poshidprob_a_2, 2);
        pos.hidbias_b = pos.hidbias_b + 2*sum(poshidprob_b_shared, 2);
        pos.labbias = pos.labbias + sum(poslabprob_1, 2) + sum(poslabprob_2, 2);
        
        
        % monitoring variables (reconstruction, sparsity)
        wvh_ha_1 = weights.vishid_a*poshidprob_a_1;
        wvh_ha_2 = weights.vishid_a*poshidprob_a_2;
        wvh_hb_shared = weights.vishid_b*poshidprob_b_shared;
        vbiasmat = repmat(weights.visbias, [1 batchsize_lab]);
        
        recon_1 = weights.visfac*(poswhaf_1.*poswhbf_shared) + wvh_ha_1 + wvh_hb_shared + vbiasmat;
        recon_2 = weights.visfac*(poswhaf_2.*poswhbf_shared) + wvh_ha_2 + wvh_hb_shared + vbiasmat;
        recon_err = sum(sum((recon_1 - data_1).^2)) + sum(sum((recon_2 - data_2).^2));
        poshidact_a = sum(poshidprob_a_1, 2) + sum(poshidprob_a_2, 2);
        poshidact_b = 2*sum(poshidprob_b_shared, 2);
        
        % convert to host doubles before storing in the CPU monitoring arrays
        recon_err_epoch(b) = double(recon_err)/(2*batchsize_lab);
        sparsity_a_epoch(b) = double(mean(poshidact_a))/(2*batchsize_lab);
        sparsity_b_epoch(b) = double(mean(poshidact_b))/(2*batchsize_lab);
        
        
        % ------------------------------------
        % identity-paired data (neg, disc.)
        % ------------------------------------
        
        % discriminative gradient = clamped stats - stats of P(h, y|v)
        grad_disc = pos;
        
        % inference (negative): P(ha1, ha2, hb, y1, y2|v1, v2)
        % hbiasmat_a / hbiasmat_b from above are reused (weights unchanged)
        lbiasmat = repmat(weights.labbias, [1 batchsize_lab]);
        
        % NOTE(review): the label layer is initialised with the raw label
        % biases (not softmax(labbias)); kept as in the original
        whl_l_1 = weights.hidlab*lbiasmat;
        whl_l_2 = whl_l_1;
        
        % initialize with 2-way connection
        % followed by sampled mean-field updates
        dishidprob_b_shared = sigmoid(wvhb_v_1 + wvhb_v_2 + hbiasmat_b);
        dishidstate_b_shared = realize(dishidprob_b_shared, params.optgpu);
        diswhbf_shared = weights.hidfac_b'*dishidstate_b_shared;
        
        for i = 1:params.nmf,
            % per-image hidden units (a)
            dishidprob_a_1 = sigmoid(weights.hidfac_a*(poswvf_1.*diswhbf_shared) + wvha_v_1 + whl_l_1 + hbiasmat_a);
            dishidstate_a_1 = realize(dishidprob_a_1, params.optgpu);
            diswhaf_1 = weights.hidfac_a'*dishidstate_a_1;
            
            dishidprob_a_2 = sigmoid(weights.hidfac_a*(poswvf_2.*diswhbf_shared) + wvha_v_2 + whl_l_2 + hbiasmat_a);
            dishidstate_a_2 = realize(dishidprob_a_2, params.optgpu);
            diswhaf_2 = weights.hidfac_a'*dishidstate_a_2;
            
            % label (softmax), sampled before being fed back to the hiddens
            dislabprob_1 = softmax_cols(weights.hidlab'*dishidstate_a_1 + lbiasmat);
            whl_l_1 = weights.hidlab*realize_mult(dislabprob_1, params.optgpu);
            
            dislabprob_2 = softmax_cols(weights.hidlab'*dishidstate_a_2 + lbiasmat);
            whl_l_2 = weights.hidlab*realize_mult(dislabprob_2, params.optgpu);
            
            % hidden units (b) shared by both images of the same identity
            dishidprob_b_shared = sigmoid(weights.hidfac_b*(poswvf_1.*diswhaf_1 + poswvf_2.*diswhaf_2) + wvhb_v_1 + wvhb_v_2 + hbiasmat_b);
            dishidstate_b_shared = realize(dishidprob_b_shared, params.optgpu);
            diswhbf_shared = weights.hidfac_b'*dishidstate_b_shared;
        end
        
        % compute gradient
        grad_disc.visfac = grad_disc.visfac - (data_1*(diswhaf_1.*diswhbf_shared)' + data_2*(diswhaf_2.*diswhbf_shared)');
        grad_disc.hidfac_a = grad_disc.hidfac_a - (dishidprob_a_1*(poswvf_1.*diswhbf_shared)' + dishidprob_a_2*(poswvf_2.*diswhbf_shared)');
        grad_disc.hidfac_b = grad_disc.hidfac_b - dishidprob_b_shared*(poswvf_1.*diswhaf_1 + poswvf_2.*diswhaf_2)';
        grad_disc.vishid_a = grad_disc.vishid_a - (data_1*dishidprob_a_1' + data_2*dishidprob_a_2');
        grad_disc.vishid_b = grad_disc.vishid_b - (data_1 + data_2)*dishidprob_b_shared';
        grad_disc.hidlab = grad_disc.hidlab - (dishidprob_a_1*dislabprob_1' + dishidprob_a_2*dislabprob_2');
        grad_disc.visbias = grad_disc.visbias - (sum(data_1, 2) + sum(data_2, 2));
        grad_disc.hidbias_a = grad_disc.hidbias_a - (sum(dishidprob_a_1, 2) + sum(dishidprob_a_2, 2));
        grad_disc.hidbias_b = grad_disc.hidbias_b - 2*sum(dishidprob_b_shared, 2);
        grad_disc.labbias = grad_disc.labbias - (sum(dislabprob_1, 2) + sum(dislabprob_2, 2));
        
        % monitoring variables (emotion recognition)
        [~, pred_1] = max(dislabprob_1, [], 1);
        [~, pred_2] = max(dislabprob_2, [], 1);
        cls_err = (mean(pred_1(:) ~= label_1(:)) + mean(pred_2(:) ~= label_2(:)))/2;
        cls_err_epoch(b) = double(cls_err);
        
        
        % ------------------------------------
        % negative phase (persistent chains)
        % ------------------------------------
        
        vbiasmat = repmat(weights.visbias, [1 negchain]);
        hbiasmat_a = repmat(weights.hidbias_a, [1 negchain]);
        hbiasmat_b = repmat(weights.hidbias_b, [1 negchain]);
        lbiasmat = repmat(weights.labbias, [1 negchain]);
        
        for i = 1:kcd,
            % visible (real-valued: mean is used as the state)
            wvh_ha_1 = weights.vishid_a*neghidstate_a_1;
            wvh_ha_2 = weights.vishid_a*neghidstate_a_2;
            wvh_hb_shared = weights.vishid_b*neghidstate_b_shared;
            
            negvisstate_1 = weights.visfac*(negwhaf_1.*negwhbf_shared) + wvh_ha_1 + wvh_hb_shared + vbiasmat;
            negwvf_1 = weights.visfac'*negvisstate_1;
            
            negvisstate_2 = weights.visfac*(negwhaf_2.*negwhbf_shared) + wvh_ha_2 + wvh_hb_shared + vbiasmat;
            negwvf_2 = weights.visfac'*negvisstate_2;
            
            % label (softmax samples)
            neglabstate_1 = realize_mult(softmax_cols(weights.hidlab'*neghidstate_a_1 + lbiasmat), params.optgpu);
            neglabstate_2 = realize_mult(softmax_cols(weights.hidlab'*neghidstate_a_2 + lbiasmat), params.optgpu);
            
            % hidden
            wvha_v_1 = weights.vishid_a'*negvisstate_1;
            wvhb_v_1 = weights.vishid_b'*negvisstate_1;
            wvha_v_2 = weights.vishid_a'*negvisstate_2;
            wvhb_v_2 = weights.vishid_b'*negvisstate_2;
            whl_l_1 = weights.hidlab*neglabstate_1;
            whl_l_2 = weights.hidlab*neglabstate_2;
            
            neghidprob_a_1 = sigmoid(weights.hidfac_a*(negwvf_1.*negwhbf_shared) + wvha_v_1 + whl_l_1 + hbiasmat_a);
            neghidstate_a_1 = realize(neghidprob_a_1, params.optgpu);
            negwhaf_1 = weights.hidfac_a'*neghidstate_a_1;
            
            neghidprob_a_2 = sigmoid(weights.hidfac_a*(negwvf_2.*negwhbf_shared) + wvha_v_2 + whl_l_2 + hbiasmat_a);
            neghidstate_a_2 = realize(neghidprob_a_2, params.optgpu);
            negwhaf_2 = weights.hidfac_a'*neghidstate_a_2;
            
            neghidprob_b_shared = sigmoid(weights.hidfac_b*(negwvf_1.*negwhaf_1 + negwvf_2.*negwhaf_2) + wvhb_v_1 + wvhb_v_2 + hbiasmat_b);
            neghidstate_b_shared = realize(neghidprob_b_shared, params.optgpu);
            negwhbf_shared = weights.hidfac_b'*neghidstate_b_shared;
        end
        
        % compute gradient
        neg.visfac = negvisstate_1*(negwhaf_1.*negwhbf_shared)' + negvisstate_2*(negwhaf_2.*negwhbf_shared)';
        neg.hidfac_a = neghidstate_a_1*(negwvf_1.*negwhbf_shared)' + neghidstate_a_2*(negwvf_2.*negwhbf_shared)';
        neg.hidfac_b = neghidstate_b_shared*(negwvf_1.*negwhaf_1)' + neghidstate_b_shared*(negwvf_2.*negwhaf_2)';
        neg.vishid_a = negvisstate_1*neghidstate_a_1' + negvisstate_2*neghidstate_a_2';
        neg.vishid_b = negvisstate_1*neghidstate_b_shared' + negvisstate_2*neghidstate_b_shared';
        neg.hidlab = neghidstate_a_1*neglabstate_1' + neghidstate_a_2*neglabstate_2';
        neg.visbias = sum(negvisstate_1, 2) + sum(negvisstate_2, 2);
        neg.hidbias_a = sum(neghidstate_a_1, 2) + sum(neghidstate_a_2, 2);
        neg.hidbias_b = 2*sum(neghidstate_b_shared, 2);
        neg.labbias = sum(neglabstate_1, 2) + sum(neglabstate_2, 2);
        
        
        % ------------------------------------
        % reweight and add gradients
        % ------------------------------------
        
        % average over the 2*batchsize_lab images / 2*negchain chain samples
        pos = replicate_struct(pos, 1/(2*batchsize_lab));
        grad_disc = replicate_struct(grad_disc, params.eta/(2*batchsize_lab));
        neg = replicate_struct(neg, 1/(2*negchain));
        
        pos = add_structs(pos, grad_disc);
        
        
        % ------------------------------------
        % regularizations
        % (l2 weight decay, sparsity)
        % ------------------------------------
        
        % l2 weight decay (separate coefficient for the 3-way factors)
        pos.visfac = pos.visfac - params.l2reg_f*weights.visfac;
        pos.hidfac_a = pos.hidfac_a - params.l2reg_f*weights.hidfac_a;
        pos.hidfac_b = pos.hidfac_b - params.l2reg_f*weights.hidfac_b;
        pos.vishid_a = pos.vishid_a - params.l2reg*weights.vishid_a;
        pos.vishid_b = pos.vishid_b - params.l2reg*weights.vishid_b;
        pos.hidlab = pos.hidlab - params.l2reg*weights.hidlab;
        
        % sparsity: push running mean activations towards the targets
        runavg_hid_a = params.sp_damp*runavg_hid_a + (1-params.sp_damp)*poshidact_a/(2*batchsize_lab);
        pos.hidbias_a = pos.hidbias_a - params.plambda_a*(runavg_hid_a - params.pbias_a);
        
        runavg_hid_b = params.sp_damp*runavg_hid_b + (1-params.sp_damp)*poshidact_b/(2*batchsize_lab);
        pos.hidbias_b = pos.hidbias_b - params.plambda_b*(runavg_hid_b - params.pbias_b);
        
        % update parameters
        [weights, agrad] = update_params(weights, agrad, pos, neg, momentum, lr, params.usepcd);
        
        % normalize factor weights if greater than threshold
        max_l2 = 1;
        weights.visfac = clip_col_norm(weights.visfac, max_l2);
        weights.hidfac_a = clip_col_norm(weights.hidfac_a, max_l2);
        weights.hidfac_b = clip_col_norm(weights.hidfac_b, max_l2);
    end
    
    history.error(t) = double(sum(recon_err_epoch));
    history.cls_error(t) = 100*double(mean(cls_err_epoch));
    history.sparsity_a(t) = double(mean(sparsity_a_epoch));
    history.sparsity_b(t) = double(mean(sparsity_b_epoch));
    
    tE = toc(tS);
    
    % average column norm of the three factor matrices (for logging)
    facnorm = (mean(sqrt(sum(weights.visfac.^2, 1))) + mean(sqrt(sum(weights.hidfac_a.^2, 1))) + ...
        mean(sqrt(sum(weights.hidfac_b.^2, 1))))/3;
    
    if params.verbose,
        fprintf('epoch %d:\t recon err = %.f\t cls err = %g\t sparsity (a) = %g\t sparsity (b) = %g, norm (3way) = %.3f (time = %g)\n', ...
            t, history.error(t), history.cls_error(t), history.sparsity_a(t), history.sparsity_b(t), ...
            facnorm, tE);
        % validation needs both the data and its labels
        if nargin >= 6 && ~isempty(xval),
            [err_val, ll] = evaluate_label_disbm(xval, yval_a, weights, 200);
            fprintf('epoch %d:\t val err = %g, ll = %g\n', t, 100*err_val, ll);
        end
    end
    
    if mod(t, params.saveiter) == 0,
        fprintf('epoch %d:\t recon err = %.f\t cls err = %g\t sparsity (a) = %g\t sparsity (b) = %g, norm (3way) = %.3f\n', ...
            t, history.error(t), history.cls_error(t), history.sparsity_a(t), history.sparsity_b(t), ...
            facnorm);
        
        % save parameters
        fname_mat_iter = sprintf('%s/%s_iter_%d_fold_%d.mat', params.savedir, params.fname, t, params.fold);
        save_params(fname_mat_iter, weights, agrad, params, t, history);
        fprintf('%s\n', fname_mat_iter);
    end
end

[weights, agrad] = save_params(fname_mat, weights, agrad, params, t, history);

return;


function p = softmax_cols(x)
% SOFTMAX_COLS  Column-wise softmax; each column is shifted by its maximum
% before exponentiation to avoid overflow.
p = exp(bsxfun(@minus, x, max(x, [], 1)));
p = bsxfun(@rdivide, p, sum(p, 1));
return;


function w = clip_col_norm(w, max_l2)
% CLIP_COL_NORM  Rescale every column of w whose l2 norm (computed with a
% 1e-6 stabilizer inside the sqrt) exceeds max_l2 so that it equals max_l2;
% columns below the threshold are multiplied by exactly 1.
n = sqrt(sum(w.^2, 1) + 1e-6);
n = max(max_l2, n);
w = bsxfun(@times, w, max_l2./n);
return;


function s = add_structs(s, t)
% ADD_STRUCTS  Field-wise sum: s.(f) = s.(f) + t.(f) for every field f of s.
fn = fieldnames(s);
for k = 1:numel(fn),
    s.(fn{k}) = s.(fn{k}) + t.(fn{k});
end
return;
