function [ cost, deltas, acc ] = grad_hid_disc(ref, out, query, target, labels, pars)
% GRAD_HID_DISC  Softmax discrimination cost/gradient on four hidden codes.
%
% Each of REF, OUT, QUERY and TARGET is split row-wise into consecutive
% blocks of pars.card(1), pars.card(2), ... rows.  Block l of each input is
% scored by grad_softmax against row l of labels.L1, labels.L2, labels.L3
% and labels.L4 respectively (columns are samples, since grad_softmax
% normalizes along dimension 1).
%
% Inputs:
%   ref, out, query, target - score matrices whose first sum(pars.card) rows
%                             are the per-attribute class scores.
%   labels - struct with fields L1..L4.  1 is added to every entry before
%            one-hot encoding (presumably the labels arrive 0-based - TODO
%            confirm with callers).  Each field is either an
%            (attributes x batch) matrix, or a vector with one entry per
%            attribute that is shared by every sample in the batch.
%   pars   - struct with fields card (classes per attribute), batchsize, and
%            alpha (weight applied to both cost and gradients).
%
% Outputs:
%   cost   - pars.alpha times the summed cross-entropy over all attributes
%            and all four inputs.
%   deltas - 4x1 cell of gradients w.r.t. ref, out, query, target (same size
%            as the inputs; rows beyond sum(pars.card) stay zero), times alpha.
%   acc    - mean accuracy over all (attribute, input) pairs.

inputs = {ref, out, query, target};
fields = {'L1', 'L2', 'L3', 'L4'};
nstreams = numel(inputs);
nattr = numel(pars.card);

deltas = cell(nstreams, 1);
for s = 1:nstreams
    deltas{s} = 0*inputs{s};
    labels.(fields{s}) = labels.(fields{s}) + 1;
end

% A label vector with exactly one entry per attribute describes the whole
% batch: replicate it so that row l / column j is attribute l of sample j.
% (Comparing against the attribute count, not the batch size, is what makes
% this broadcast fire whenever such a vector is given, and never on
% per-sample labels.)
if numel(labels.L1) == nattr
    for s = 1:nstreams
        labels.(fields{s}) = repmat(labels.(fields{s})(:), 1, pars.batchsize);
    end
end

cost = 0;
acc = 0;
offset = 0;
for l = 1:nattr
    rows = offset + (1:pars.card(l));
    acc_attr = 0;
    for s = 1:nstreams
        onehot = oneofc(labels.(fields{s})(l,:)', pars.card(l));
        [ cost_l, delta_l, acc_l ] = grad_softmax(onehot, inputs{s}(rows, :), pars);
        cost = cost + cost_l;
        deltas{s}(rows,:) = deltas{s}(rows,:) + delta_l;
        acc_attr = acc_attr + acc_l;
    end
    acc = acc + acc_attr;
    offset = offset + pars.card(l);
end

acc = acc / (nattr * nstreams);

for s = 1:nstreams
    deltas{s} = pars.alpha*deltas{s};
end

cost = pars.alpha * cost;

end

function [ cost, delta, acc ] = grad_softmax(L, P, pars)
% GRAD_SOFTMAX  Column-wise softmax cross-entropy, its gradient and accuracy.
%
% Inputs:
%   L    - one-hot targets, expected to be the same size as P; L(c,j) is
%          nonzero for the true class c of column j.  The row indices from
%          find(L) are compared with the per-column argmax, so each column
%          should contain exactly one nonzero.
%   P    - unnormalized scores, one row per class and one column per sample.
%   pars - struct; pars.batchsize divides both cost and gradient.
%
% Outputs:
%   cost  - -(1/batchsize) * sum(L .* log(softmax(P))).
%   delta - (1/batchsize) * (softmax(P) - L), the gradient of cost w.r.t. P.
%   acc   - fraction of columns whose argmax equals the labelled class.

% Shift each column by its max for numerical stability.
shifted = bsxfun(@minus, P, max(P, [], 1));
expP = exp(shifted);
% Explicit dim 1: with a single class row, sum() would otherwise add across
% the batch instead of normalizing each column.
denom = sum(expP, 1);
pred = bsxfun(@rdivide, expP, denom);
% Log-softmax computed directly (denom >= 1, so this is always finite);
% log(pred) would give -Inf on underflow and 0*-Inf = NaN in the cost.
logpred = bsxfun(@minus, shifted, log(denom));

[~, pred_quantized] = max(pred, [], 1);
[ yvec, ~ ] = find(L);
acc = mean(pred_quantized(:) == yvec(:));

cost = (-1/pars.batchsize) * sum(sum(L .* logpred));
delta = (1/pars.batchsize) * (pred - L);

end

