% ===========================================
% Demo code for the disentangling Boltzmann
% machine (disBM) on the Toronto Face Database (TFD)
% ===========================================

function [weights, params] = run_disbm_tfd(params, foldlist)
% RUN_DISBM_TFD  Train a disBM on each requested TFD fold and evaluate it.
%
%   [weights, params] = RUN_DISBM_TFD(params, foldlist)
%
%   Inputs
%     params   - hyper-parameter struct, completed by add_defaults. After
%                that call it must provide optgpu, type_a, type_b and fname
%                (all read below). Defaults to an empty struct if omitted.
%     foldlist - fold indices to run (default 1:5); row or column vector.
%
%   Outputs
%     weights, params - as returned by the last call to disbm_train.
%
%   For each requested fold:
%     * labeled and unlabeled data are standardized with statistics of the
%       unlabeled set;
%     * a disBM is obtained from disbm_train, and expression / identity
%       hidden units are inferred for every labeled sample;
%     * both unit groups are scored on expression classification
%       (eval_cls, swept over Clist) and face verification (eval_vrf on
%       pairs from makepairs_for_verification).
%   Per-fold results are printed. When fewer than all folds are run, they
%   are also appended to log/<dataset>_<type_a>_<type_b>_{cls,vrf}_fold<k>.txt.
%
%   When all folds are run:
%     * the C with the best mean validation accuracy is chosen;
%     * classification is re-run with that single C on every fold;
%     * summaries are printed and appended to log/..._{cls,vrf}.txt.
%   The second pass relies on disbm_train loading the model trained in the
%   first pass (per the comment at that call).

% Number of folds in the evaluation protocol. The aggregated evaluation at
% the end only runs when this many folds were requested.
numfolds = 5;

if ~exist('params', 'var'),
    params = struct();
end
if ~exist('foldlist', 'var'),
    foldlist = 1:numfolds;
end
nrun = numel(foldlist);

% ----------------------------
% load tfd data
% ----------------------------

disp('Loading data...');
[xlab, ylab_ex, ylab_id, folds, folds_ex, folds_id, xunlab] = load_tfd_omp;
disp('Done.');

% -- preprocessing
% Standardize every feature (row) with the mean and std of the unlabeled
% set. The 0.01 added to the variance keeps near-constant features from
% being blown up by a tiny divisor.
m = mean(xunlab, 2);
stds = sqrt(var(xunlab, [], 2) + 0.01);
xunlab = bsxfun(@rdivide, bsxfun(@minus, xunlab, m), stds);
xlab = bsxfun(@rdivide, bsxfun(@minus, xlab, m), stds);

% ----------------------------
% hyperparameters
% ----------------------------

params = add_defaults(params);
params.dataset = 'tfd_omp';
% Number of distinct expression labels, excluding -1 (presumably the
% "no label" marker -- TODO confirm against load_tfd_omp).
params.numlab = sum(unique(ylab_ex) ~= -1);
params.numvis = size(xlab, 1);
optgpu = params.optgpu;
type_ex = params.type_a;
type_id = params.type_b;

% ----------------------------
% learning and evaluation
% ----------------------------

% Regularization constants swept by eval_cls.
Clist = [0.003, 0.01, 0.03, 0.1, 0.3, 1, 10, 30, 100, 300, 1000];
numC = numel(Clist);

% Rows are indexed by position k in foldlist, not by the fold number.
% This keeps subsets such as foldlist = [3 5] in bounds without padding
% the arrays with all-zero rows.
acc_val_list_ex = zeros(nrun, numC);
acc_val_list_id = zeros(nrun, numC);
acc_ts_list_ex = zeros(nrun, numC);
acc_ts_list_id = zeros(nrun, numC);
auc_val_list_ex = zeros(nrun, 1);
auc_val_list_id = zeros(nrun, 1);
auc_ts_list_ex = zeros(nrun, 1);
auc_ts_list_id = zeros(nrun, 1);

for k = 1:nrun,
    fold_id = foldlist(k);

    % split data into train and validation folds for disBM training
    tr_id = folds(:, fold_id) == 1;
    val_id = folds(:, fold_id) == 2;

    xtrain = xlab(:, tr_id);
    ytrain_ex = ylab_ex(tr_id);
    ytrain_id = ylab_id(tr_id);
    xval = xlab(:, val_id);
    yval_ex = ylab_ex(val_id);
    yval_id = ylab_id(val_id);

    % ----------------------------
    % train disBM
    % ----------------------------
    params.fold = fold_id;
    params.optgpu = optgpu;

    [weights, params] = disbm_train(xtrain, ytrain_ex, ytrain_id, xunlab, params, xval, yval_ex, yval_id);
    infer = disbm_infer(weights, params);

    % infer latent factors of variation for every labeled sample
    [hlab_ex, hlab_id] = infer(xlab);

    % ----------------------------
    % evaluate classification
    % - cross validation over Clist
    % ----------------------------

    tr_id = folds_ex(:, fold_id) == 1;
    val_id = folds_ex(:, fold_id) == 2;
    ts_id = folds_ex(:, fold_id) == 3;

    ytrain_ex = ylab_ex(tr_id);
    yval_ex = ylab_ex(val_id);
    ytest_ex = ylab_ex(ts_id);

    % use expression units
    htrain_ex = hlab_ex(:, tr_id);
    hval_ex = hlab_ex(:, val_id);
    htest_ex = hlab_ex(:, ts_id);

    [acc_val, acc_ts] = eval_cls(Clist, htrain_ex, ytrain_ex, hval_ex, yval_ex, htest_ex, ytest_ex);
    acc_val_list_ex(k, :) = acc_val;
    acc_ts_list_ex(k, :) = acc_ts;
    % test accuracy at the C with the best validation accuracy
    [best_val_ex, best] = max(acc_val);
    acc_ts_ex = acc_ts_list_ex(k, best);

    % use identity units (still predicting expression labels)
    htrain_id = hlab_id(:, tr_id);
    hval_id = hlab_id(:, val_id);
    htest_id = hlab_id(:, ts_id);

    [acc_val, acc_ts] = eval_cls(Clist, htrain_id, ytrain_ex, hval_id, yval_ex, htest_id, ytest_ex);
    acc_val_list_id(k, :) = acc_val;
    acc_ts_list_id(k, :) = acc_ts;
    [best_val_id, best] = max(acc_val);
    acc_ts_id = acc_ts_list_id(k, best);

    % ----------------------------
    % evaluate verification
    % ----------------------------

    val_id = folds_id(:, fold_id) == 2;
    ts_id = folds_id(:, fold_id) == 3;

    yval_id = ylab_id(val_id);
    ytest_id = ylab_id(ts_id);

    % Reset the global RNG so that any randomness in pair construction is
    % identical across runs and across the two feature types below.
    rng('default');
    pairs_val = makepairs_for_verification(yval_id);
    pairs_ts = makepairs_for_verification(ytest_id);

    % use expression units
    hval_ex = hlab_ex(:, val_id);
    htest_ex = hlab_ex(:, ts_id);

    auc_val_list_ex(k) = eval_vrf(pairs_val, hval_ex);
    auc_ts_list_ex(k) = eval_vrf(pairs_ts, htest_ex);

    % use identity units
    hval_id = hlab_id(:, val_id);
    htest_id = hlab_id(:, ts_id);

    auc_val_list_id(k) = eval_vrf(pairs_val, hval_id);
    auc_ts_list_id(k) = eval_vrf(pairs_ts, htest_id);

    % Per-fold logs are only written when the model is trained on a subset
    % of folds; full runs are summarized after the loop instead.
    if nrun < numfolds,
        append_log(sprintf('log/%s_%s_%s_cls_fold%d.txt', params.dataset, type_ex, type_id, fold_id), ...
            ['feat: expr, acc_val = %g, acc_test = %g (%s)\n', ...
             'feat: id, acc_val = %g, acc_test = %g (%s)\n\n'], ...
            best_val_ex, acc_ts_ex, params.fname, best_val_id, acc_ts_id, params.fname);

        append_log(sprintf('log/%s_%s_%s_vrf_fold%d.txt', params.dataset, type_ex, type_id, fold_id), ...
            ['feat: expr, auc_val = %g, auc_test = %g (%s)\n', ...
             'feat: id, auc_val = %g, auc_test = %g (%s)\n\n'], ...
            auc_val_list_ex(k), auc_ts_list_ex(k), params.fname, ...
            auc_val_list_id(k), auc_ts_list_id(k), params.fname);
    end

    fprintf('[fold %d] feat: expr, task: cls, acc_val = %g, acc_test = %g\n', fold_id, best_val_ex, acc_ts_ex);
    fprintf('[fold %d] feat: id, task: cls, acc_val = %g, acc_test = %g\n', fold_id, best_val_id, acc_ts_id);
    fprintf('[fold %d] feat: expr, task: vrf, auc_val = %g, auc_test = %g\n', fold_id, auc_val_list_ex(k), auc_ts_list_ex(k));
    fprintf('[fold %d] feat: id, task: vrf, auc_val = %g, auc_test = %g\n', fold_id, auc_val_list_id(k), auc_ts_list_id(k));

    % clear variables for memory efficiency
    clear htrain_ex htrain_id hval_ex hval_id htest_ex htest_id hlab_ex hlab_id;
end

% ---------------------------
% evaluate classification
% on test set
% ---------------------------

if nrun == numfolds,
    % choose C by mean validation accuracy across folds
    mean_val_ex = mean(acc_val_list_ex, 1);
    [acc_val_ex, best] = max(mean_val_ex);
    bestC_ex = Clist(best);
    acc_ts_ex = mean(acc_ts_list_ex(:, best));
    acc_std_ex = std(acc_ts_list_ex(:, best));

    mean_val_id = mean(acc_val_list_id, 1);
    [acc_val_id, best] = max(mean_val_id);
    bestC_id = Clist(best);
    acc_ts_id = mean(acc_ts_list_id(:, best));
    acc_std_id = std(acc_ts_list_id(:, best));

    fprintf('feat: expr, task: cls, retrain: no, acc_val = %g, acc_test = %g (std = %g) (%s)\n', acc_val_ex, acc_ts_ex, acc_std_ex, params.fname);
    fprintf('feat: id, task: cls, retrain: no, acc_val = %g, acc_test = %g (std = %g) (%s)\n', acc_val_id, acc_ts_id, acc_std_id, params.fname);

    append_log(sprintf('log/%s_%s_%s_cls.txt', params.dataset, type_ex, type_id), ...
        ['feat: expr, retrain: no, acc_val = %g, acc_test = %g (std = %g) (%s)\n', ...
         'feat: id, retrain: no, acc_val = %g, acc_test = %g (std = %g) (%s)\n\n'], ...
        acc_val_ex, acc_ts_ex, acc_std_ex, params.fname, ...
        acc_val_id, acc_ts_id, acc_std_id, params.fname);

    fprintf('evaluate on test set, while retraining SVM with validation set\n');

    acc_rt_list_ex = zeros(nrun, 1);
    acc_rt_list_id = zeros(nrun, 1);

    for k = 1:nrun,
        fold_id = foldlist(k);

        % split data into train/val folds (same split as during training)
        tr_id = folds(:, fold_id) == 1;
        val_id = folds(:, fold_id) == 2;

        xtrain = xlab(:, tr_id);
        ytrain_ex = ylab_ex(tr_id);
        ytrain_id = ylab_id(tr_id);
        xval = xlab(:, val_id);
        yval_ex = ylab_ex(val_id);
        % Recompute here: otherwise yval_id would still hold the labels of
        % the last fold's verification split from the loop above.
        yval_id = ylab_id(val_id);

        % ----------------------------
        % load disBM that was trained above.
        % ----------------------------

        params.fold = fold_id;
        params.optgpu = 0;

        % This call just loads the existing model (does not re-train).
        [weights, params] = disbm_train(xtrain, ytrain_ex, ytrain_id, xunlab, params, xval, yval_ex, yval_id);
        infer = disbm_infer(weights, params);

        % inference.
        [hlab_ex, hlab_id] = infer(xlab);

        % ----------------------------
        % evaluate classification. When we pass in bestC_ex, bestC_id rather than
        % an array of Cvals for validation, the final result is computed.
        % Note that we only need to do emotion recognition here because we didn't
        % use SVM for face verification.
        % ----------------------------

        tr_id = folds_ex(:, fold_id) == 1;
        val_id = folds_ex(:, fold_id) == 2;
        ts_id = folds_ex(:, fold_id) == 3;

        ytrain_ex = ylab_ex(tr_id);
        yval_ex = ylab_ex(val_id);
        ytest_ex = ylab_ex(ts_id);

        % use expression units
        htrain_ex = hlab_ex(:, tr_id);
        hval_ex = hlab_ex(:, val_id);
        htest_ex = hlab_ex(:, ts_id);

        [~, acc_test] = eval_cls(bestC_ex, htrain_ex, ytrain_ex, hval_ex, yval_ex, htest_ex, ytest_ex);
        acc_rt_list_ex(k) = acc_test;

        % use identity units
        htrain_id = hlab_id(:, tr_id);
        hval_id = hlab_id(:, val_id);
        htest_id = hlab_id(:, ts_id);

        [~, acc_test] = eval_cls(bestC_id, htrain_id, ytrain_ex, hval_id, yval_ex, htest_id, ytest_ex);
        acc_rt_list_id(k) = acc_test;

        % clear variables for memory efficiency
        clear htrain_ex htrain_id hval_ex hval_id htest_ex htest_id hlab_ex hlab_id;
    end

    % ---------------------------
    % save results
    % ---------------------------

    acc_ts_ex = mean(acc_rt_list_ex);
    acc_std_ex = std(acc_rt_list_ex);
    acc_ts_id = mean(acc_rt_list_id);
    acc_std_id = std(acc_rt_list_id);
    auc_val_ex = mean(auc_val_list_ex);
    auc_ts_ex = mean(auc_ts_list_ex);
    auc_std_ex = std(auc_ts_list_ex);
    auc_val_id = mean(auc_val_list_id);
    auc_ts_id = mean(auc_ts_list_id);
    auc_std_id = std(auc_ts_list_id);

    fprintf('feat: expr, task: cls, retrain: yes, acc_val = %g, acc_test = %g (std = %g) (%s)\n', acc_val_ex, acc_ts_ex, acc_std_ex, params.fname);
    fprintf('feat: id, task: cls, retrain: yes, acc_val = %g, acc_test = %g (std = %g) (%s)\n', acc_val_id, acc_ts_id, acc_std_id, params.fname);
    fprintf('feat: expr, task: vrf, auc_val = %g, auc_test = %g (std = %g) (%s)\n', auc_val_ex, auc_ts_ex, auc_std_ex, params.fname);
    fprintf('feat: id, task: vrf, auc_val = %g, auc_test = %g (std = %g) (%s)\n', auc_val_id, auc_ts_id, auc_std_id, params.fname);

    append_log(sprintf('log/%s_%s_%s_cls.txt', params.dataset, type_ex, type_id), ...
        ['feat: expr, retrain: yes, acc_val = %g, acc_test = %g (std = %g) (%s)\n', ...
         'feat: id, retrain: yes, acc_val = %g, acc_test = %g (std = %g) (%s)\n\n'], ...
        acc_val_ex, acc_ts_ex, acc_std_ex, params.fname, ...
        acc_val_id, acc_ts_id, acc_std_id, params.fname);

    append_log(sprintf('log/%s_%s_%s_vrf.txt', params.dataset, type_ex, type_id), ...
        ['feat: expr, auc_val = %g, auc_test = %g (std = %g) (%s)\n', ...
         'feat: id, auc_val = %g, auc_test = %g (std = %g) (%s)\n\n'], ...
        auc_val_ex, auc_ts_ex, auc_std_ex, params.fname, ...
        auc_val_id, auc_ts_id, auc_std_id, params.fname);
end

return;


function append_log(fname, fmt, varargin)
% APPEND_LOG  Append fprintf-formatted text to the file FNAME.
%   Creates the containing folder if it does not exist. Raises an error,
%   rather than writing to an invalid file id, if the file cannot be
%   opened. The file is closed even if fprintf fails.
logdir = fileparts(fname);
if ~isempty(logdir) && ~exist(logdir, 'dir'),
    mkdir(logdir);
end
fid = fopen(fname, 'a+');
if fid < 0,
    error('run_disbm_tfd:cannotOpenLog', 'Cannot open log file "%s".', fname);
end
% onCleanup closes the file when this function exits, normally or by error.
closer = onCleanup(@() fclose(fid)); %#ok<NASGU>
fprintf(fid, fmt, varargin{:});
