% Evaluate supervised version of disBM.
% Assumes labels are used for factor A.
function [err, cost] = evaluate_label_disbm(data, label, weights, nmf)
% EVALUATE_LABEL_DISBM  Classification error and average loss, computed in
% mini-batches of 100 examples.
%
%   data    : matrix with one example per column (sliced as data(:, idx))
%   label   : per-example labels, indexed linearly as label(idx); they are
%             compared against the index returned by evaluate_disbm_sub
%   weights : parameter struct; converted with gpu2cpu_struct before use
%   nmf     : (optional) maximum number of mean-field iterations.
%             Defaults to 200 when omitted or passed as [].
%
%   err     : fraction of examples whose prediction differs from its label
%   cost    : sum of the per-batch losses divided by the number of examples

% Treat an empty nmf like an omitted one. With nmf = [] the range 1:nmf is
% empty, so no mean-field iterations would run and the predictions would
% come from the label bias alone.
if nargin < 4 || isempty(nmf),
    nmf = 200;
end

% make parameters in double precision
weights = gpu2cpu_struct(weights);

data = double(data);
N = size(data, 2);
batchsize = 100;

err = 0;
cost = 0;
for b = 1:ceil(N/batchsize),
    % columns of this batch; the last batch may be shorter
    batchidx = (b-1)*batchsize+1:min(b*batchsize, N);
    data_batch = data(:, batchidx);
    label_batch = label(batchidx);
    
    [pred, c] = evaluate_disbm_sub(data_batch, label_batch, weights, nmf);
    
    % accumulate misclassification count and summed loss
    err = err + sum(label_batch(:) ~= pred(:));
    cost = cost + c;
end

% normalize by the number of examples
err = err / N;
cost = cost / N;

return;


function [pred, c] = evaluate_disbm_sub(data, label, weights, nmf)
% Mean-field inference on one batch (one example per column of data).
%   pred : row index of the largest label activation for every column
%   c    : loss returned by computeLoss for the final label activations
%          against multi_output(label, numlab)
% At most nmf mean-field sweeps are run; iteration stops early once the
% mean absolute change of the b-group hidden units drops below 1e-6.

ncases = size(data, 2);
nclass = size(weights.hidlab, 2);
target = multi_output(label, nclass);

% projections of the visible data (constant during inference)
vis2fac = weights.visfac'*data;
vis2ha = weights.vishid_a'*data;
vis2hb = weights.vishid_b'*data;

% biases replicated across the batch
bias_ha = repmat(weights.hidbias_a, [1 ncases]);
bias_hb = repmat(weights.hidbias_b, [1 ncases]);
bias_lab = repmat(weights.labbias, [1 ncases]);

% label units start out at their bias values
lab = bias_lab;
lab2ha = weights.hidlab*lab;

% initialize with direct visible-hidden connection
hb_prev = sigmoid(2*vis2hb + bias_hb);
hb2fac = weights.hidfac_b'*hb_prev;

% mean-field update (update a -> b)
for sweep = 1:nmf,
    % hidden group a
    ha = sigmoid(weights.hidfac_a*(vis2fac.*hb2fac) + vis2ha + lab2ha + bias_ha);
    ha2fac = weights.hidfac_a'*ha;
    
    % label: column-wise softmax, shifted by the column max before exp
    act = weights.hidlab'*ha + bias_lab;
    expact = exp(bsxfun(@minus, act, max(act, [], 1)));
    lab = bsxfun(@rdivide, expact, sum(expact, 1));
    lab2ha = weights.hidlab*lab;
    
    % hidden group b (share hidden unit with same ids)
    hb = sigmoid(2*weights.hidfac_b*(vis2fac.*ha2fac) + 2*vis2hb + bias_hb);
    hb2fac = weights.hidfac_b'*hb;
    
    % stop once group b has stopped moving
    delta = mean(abs(hb(:) - hb_prev(:)));
    if delta < 1e-6,
        break;
    end
    hb_prev = hb;
end

[~, pred] = max(lab, [], 1);
c = computeLoss(lab, target);

return;


function [f, out, dout] = computeLoss(yr, y)
% Cross-entropy style loss between targets y and predictions yr.
%   yr   : predictions; clipped element-wise to [1e-8, 1-1e-8] before log
%   y    : targets, same size as yr
%   f    : -sum over all entries of y.*log(clipped yr)
%   out  : the clipped predictions (set only when 2+ outputs are requested)
%   dout : -y.*(1./out)            (set only when 2+ outputs are requested)

% keep predictions away from 0 and 1 so log() and 1./ stay finite
p = min(max(yr, 1e-8), 1-1e-8);

%%% compute loss function
weighted_log = y.*log(p);
f = -sum(sum(weighted_log));

% extra outputs are only produced on request
if nargout < 2,
    return;
end

out = p;
dout = -y.*(1./out);

return;
