% ===========================================
% train disBM.
%
% supported type pairs, given as params.type_a / params.type_b (ex / id):
% label     / clamp
% man       / man
%
% weights (struct)
%       visfac      : [numvis x numfac]
%       hidfac_a    : [numhid_a x numfac]
%       hidfac_b    : [numhid_b x numfac]
%       vishid_a    : [numvis x numhid_a]
%       vishid_b    : [numvis x numhid_b]
%       visbias     : [numvis x 1]
%       hidbias_a   : [numhid_a x 1]
%       hidbias_b   : [numhid_b x 1]
%       hidlab      : [numhid_a x numlab]
%       labbias     : [numlab x 1]
%
% ===========================================

function [weights, params] = disbm_train(xtrain, ytrain_a, ytrain_b, xunlab, params, xval, yval_a, yval_b)
% DISBM_TRAIN  Entry point: dispatch to the trainer matching the requested
% (params.type_a, params.type_b) pair.
%
%   Supported pairs:
%       'label' / 'clamp'  -> disbm_train_real_label_clamp_
%       'man'   / 'man'    -> disbm_train_man_man_
%   Any other combination raises an error.
%
%   Returns the weights produced (or loaded) by the selected trainer and the
%   params struct as updated by it.

% Read the type settings up front (typein is read but not used here).
typein = params.typein; %#ok<NASGU>
kind_a = params.type_a;
kind_b = params.type_b;

% Reset the random number generator to its default state.
rng('default');

% initialize jacket: select the GPU and normalize the flag to 1
if params.optgpu
    gselect(1);
    params.optgpu = 1;
end

% Decide which training routine applies.
use_label_clamp = strcmp(kind_a, 'label') && strcmp(kind_b, 'clamp');
use_man_man     = strcmp(kind_a, 'man')   && strcmp(kind_b, 'man');

if use_label_clamp
    % xunlab and yval_b are not passed on for this setting.
    [ weights, params ] = disbm_train_real_label_clamp_(xtrain, ytrain_a, ytrain_b, params, xval, yval_a);
elseif use_man_man
    [ weights, params ] = disbm_train_man_man_(xtrain, ytrain_a, ytrain_b, xunlab, params, xval, yval_a, yval_b);
else
    error('Undefined training settings!');
end

end

% Label-clamp training.
function [ weights, params ] = disbm_train_real_label_clamp_(xtrain, ytrain_a, ytrain_b, params, xval, yval_a)
% Train the label/clamp model, or load previously saved weights.
%
%   A descriptive base name is built from the hyper-parameters in params.
%   If '<savedir>/<base>_iter_<maxiter>_fold_<fold>.mat' already exists, the
%   'weights' variable is loaded from it; otherwise
%   disbm_train_real_label_clamp is run and its weights, grad and params are
%   saved to that file.
%
%   Returns:
%       weights : trained or cached weights
%       params  : params with params.fname set to '<base>_iter_<maxiter>'
%                 (no fold suffix), where <base> is taken from the params in
%                 effect after training/loading.

% base file name encoding the hyper-parameters
params.fname = sprintf('%s_%s_%s_%s_v_%d_ha_%d_hb_%d_f_%d_eps_%g_eta_%g_l2f_%g_l2r_%g_%s_pba_%g_pla_%g_pbb_%g_plb_%g_kcd_%d_nmf_%d_bs_%d_neg_%d', ...
    params.dataset, params.typein, params.type_a, params.type_b, params.numvis, params.numhid_a, params.numhid_b, params.numfac, ...
    params.eps, params.eta, params.l2reg_f, params.l2reg, params.sp_type, params.pbias_a, params.plambda_a, params.pbias_b, params.plambda_b, ...
    params.kcd, params.nmf, params.batchsize, params.negchain);

% Name of the result file for this run. A single sprintf is used on purpose:
% formatting the result a second time would re-interpret any '%' or '\'
% contained in the string fields as format/escape sequences.
fname = sprintf('%s_iter_%d_fold_%d', params.fname, params.maxiter, params.fold);
matpath = sprintf('%s/%s.mat', params.savedir, fname);

if exist(matpath, 'file'),
    % reuse cached result; assign explicitly instead of letting load create
    % the variable in the workspace
    cached = load(matpath, 'weights');
    weights = cached.weights;
else
    [weights, grad, params] = disbm_train_real_label_clamp(xtrain, ytrain_a, ytrain_b, params, xval, yval_a);

    % params now comes from the trainer; make sure its save directory exists
    % so the result of the training run is not lost on save.
    if ~isempty(params.savedir) && exist(params.savedir, 'dir') ~= 7,
        mkdir(params.savedir);
    end
    save(sprintf('%s/%s.mat', params.savedir, fname), 'weights', 'grad', 'params');
end

params.fname = sprintf('%s_iter_%d', params.fname, params.maxiter);
end

% Manifold training.
function [ weights, params ] = disbm_train_man_man_(xtrain, ytrain_a, ytrain_b, xunlab, params, xval, yval_a, yval_b)
% Train the man/man (manifold) model, or load previously saved weights.
%
%   A descriptive base name is built from the hyper-parameters in params.
%   If '<savedir>/<base>_iter_<maxiter>_fold<fold>.mat' already exists, the
%   'weights' variable is loaded from it; otherwise disbm_train_man_man is
%   run and its weights, grad and params are saved to that file.
%
%   Returns:
%       weights : trained or cached weights
%       params  : params with params.fname set to
%                 '<base>_iter_<maxiter>_fold<fold>'

% base file name encoding the hyper-parameters
params.fname = sprintf('%s_%s_%s_%s_v%d_ha%d_hb%d_f%d_epsa%g_epsm%g_l2r%g_pba%g_pla%g_pbb%g_plb%g_kcd%d_kmf%d_bs%d_kma%.2g_ta%.2g_kmb%.2g_tb%.2g_mt%s', ...
    params.dataset, params.typein, params.type_a, params.type_b, params.numvis, params.numhid_a, ...
    params.numhid_b, params.numfac, params.epsilon_add, params.epsilon_mult, params.l2reg, params.pbias_a, ...
    params.plambda_a, params.pbias_b, params.plambda_b, params.kcd, params.kmf, ...
    params.batchsize, params.kman_a, params.thresh_a, params.kman_b, params.thresh_b, params.mantype);

% Name of the result file for this run. A single sprintf is used on purpose:
% formatting the result a second time would re-interpret any '%' or '\'
% contained in the string fields as format/escape sequences.
fname = sprintf('%s_iter_%d_fold%d', params.fname, params.maxiter, params.fold);
matpath = sprintf('%s/%s.mat', params.savedir, fname);

if exist(matpath, 'file'),
    % reuse cached result; assign explicitly instead of letting load create
    % the variable in the workspace
    cached = load(matpath, 'weights');
    weights = cached.weights;
else
    [weights, grad, params] = disbm_train_man_man(xtrain, ytrain_a, ytrain_b, xunlab, params, xval, yval_a, yval_b);

    % params now comes from the trainer; make sure its save directory exists
    % so the result of the training run is not lost on save.
    if ~isempty(params.savedir) && exist(params.savedir, 'dir') ~= 7,
        mkdir(params.savedir);
    end
    save(sprintf('%s/%s.mat', params.savedir, fname), 'weights', 'grad', 'params');
end

params.fname = fname;
end
