function [ha, hb] = disbm_infer_label(data, weights, params, nmf)
% DISBM_INFER_LABEL  Batched mean-field inference of the two hidden layers.
%
%   [ha, hb] = disbm_infer_label(data, weights, params, nmf)
%
%   data    - input matrix with one sample per column (N = size(data, 2)).
%   weights - model parameter struct. It is passed through gpu2cpu_struct
%             first (per the original note, to make all weights double
%             precision; presumably also moves gpuArray fields to the CPU --
%             confirm against that helper).
%   params  - option struct. Only params.nmf is read, and only when nmf is
%             not supplied.
%   nmf     - (optional) maximum number of mean-field iterations.
%
%   ha      - size(weights.hidfac_a, 1) x N matrix of activations.
%   hb      - size(weights.hidfac_b, 1) x N matrix of activations.
%
%   Columns are processed in chunks of 100 by disbm_infer_label_sub.

if nargin < 4,
    nmf = params.nmf;
end

% make all in double precision
weights = gpu2cpu_struct(weights);

numSamples = size(data, 2);
chunk = 100;

ha = zeros(size(weights.hidfac_a, 1), numSamples);
hb = zeros(size(weights.hidfac_b, 1), numSamples);

% walk the columns in consecutive chunks; the last chunk may be shorter
for first = 1:chunk:numSamples,
    cols = first:min(first + chunk - 1, numSamples);
    [ha(:, cols), hb(:, cols)] = disbm_infer_label_sub(data(:, cols), weights, nmf);
end

return;

function [ha, hb] = disbm_infer_label_sub(data, weights, nmf)
% DISBM_INFER_LABEL_SUB  Mean-field inference for one batch of columns.
%
%   [ha, hb] = disbm_infer_label_sub(data, weights, nmf)
%
%   data    - one batch of samples, one per column; cast to double.
%   weights - struct read through the fields vishid_a, vishid_b, visfac,
%             hidfac_a, hidfac_b, hidlab, hidbias_a, hidbias_b and labbias.
%   nmf     - maximum number of mean-field iterations. The loop stops early
%             once the mean absolute change in hb drops below 1e-8.
%
%   A column-wise softmax label distribution is updated inside the loop
%   and feeds the ha update; it is not returned.
%
%   If nmf is empty or < 1, hb keeps its bottom-up initialisation and ha
%   receives a single update from it, so both outputs are always assigned.

numcols = size(data, 2);
data = double(data);

% projections of the data; these are not reassigned during the iterations
wvha_v = weights.vishid_a'*data;
wvhb_v = weights.vishid_b'*data;
wvf = weights.visfac'*data;

hbiasmat_a = repmat(weights.hidbias_a, [1 numcols]);
hbiasmat_b = repmat(weights.hidbias_b, [1 numcols]);
lbiasmat = repmat(weights.labbias, [1 numcols]);

% initialize with visible-hidden connection
hb = sigmoid(2*wvhb_v + hbiasmat_b);
whbf = weights.hidfac_b'*hb;
hb_old = hb;

% initial label distribution comes from the label biases alone
lab = softmax_cols(lbiasmat);
whl = weights.hidlab*lab;

% without this guard the loop below would not run and ha would be unassigned
if isempty(nmf) || nmf < 1,
    ha = sigmoid(weights.hidfac_a*(wvf.*whbf) + wvha_v + whl + hbiasmat_a);
    return;
end

% mean-field update (update a -> b)
for iter = 1:nmf,
    ha = sigmoid(weights.hidfac_a*(wvf.*whbf) + wvha_v + whl + hbiasmat_a);
    whaf = weights.hidfac_a'*ha;

    lab = softmax_cols(weights.hidlab'*ha + lbiasmat);
    whl = weights.hidlab*lab;

    hb = sigmoid(2*weights.hidfac_b*(wvf.*whaf) + 2*wvhb_v + hbiasmat_b);
    whbf = weights.hidfac_b'*hb;

    % named 'delta' (not 'diff') so MATLAB's builtin diff is not shadowed
    delta = mean(mean(abs(hb - hb_old), 2), 1);
    if delta < 1e-8,
        break;
    end
    hb_old = hb;
end

return;

function p = softmax_cols(x)
% SOFTMAX_COLS  Column-wise softmax; subtracts each column's max before exp
% so large scores do not overflow.
p = exp(bsxfun(@minus, x, max(x, [], 1)));
p = bsxfun(@rdivide, p, sum(p, 1));
return;
