function [ grad, cost ] = manifold_grad(weights,D1,D2,D3,params,gradcheck)
%MANIFOLD_GRAD Cost and gradient of a pairwise similarity/margin objective
%   on two groups of hidden units (A and B) inferred by alternating updates.
%
%   [grad, cost] = manifold_grad(weights, D1, D2, D3, params)
%   [grad, cost] = manifold_grad(weights, D1, D2, D3, params, gradcheck)
%
%   Inputs
%     weights   struct with fields visfac, hidfac_a, hidfac_b, vishid_a,
%               vishid_b, hidbias_a, hidbias_b.
%     D1,D2,D3  data matrices, one case per column (numdata = size(D1,2)).
%               Column i of D1/D2/D3 forms one triplet:
%                 - A codes of D1 and D2 are pulled together, A codes of
%                   D1 and D3 are pushed at least thresh_a apart;
%                 - B codes of D1 and D3 are pulled together, B codes of
%                   D1 and D2 are pushed at least thresh_b apart.
%     params    struct with fields kmf (number of alternating A/B update
%               iterations), kman_a, kman_b (penalty weights) and
%               thresh_a, thresh_b (margins).
%     gradcheck optional; when equal to 1 the update order is fixed to
%               "A first" so the function is deterministic. Otherwise the
%               order (A first / B first) is drawn at random per call.
%
%   Outputs
%     grad      struct with the same field names as weights holding the
%               accumulated gradients.
%     cost      scalar objective value.
%
%   Requires a SIGMOID function on the path (not defined in this file).
%
%   NOTE(review): the backward pass only walks iterations kmf+1 down to
%   max(2, kmf-4), so for kmf > 6 the returned gradient is a truncated
%   approximation - confirm this is intended before gradient-checking
%   with large kmf.

visfac = weights.visfac;
hidfac_a = weights.hidfac_a;
hidfac_b = weights.hidfac_b;
vishid_a = weights.vishid_a;
vishid_b = weights.vishid_b;
hidbias_a = weights.hidbias_a;
hidbias_b = weights.hidbias_b;

kmf = params.kmf;
kman_a = params.kman_a;
kman_b = params.kman_b;
thresh_a = params.thresh_a;
thresh_b = params.thresh_b;

% Activation histories: slice 1 is the initial (feed-forward) estimate,
% slice k+1 is the state after the k-th alternating update.
numdata = size(D1,2);
A_blank = 0*repmat(hidbias_a,[1,numdata,kmf+1]);
B_blank = 0*repmat(hidbias_b,[1,numdata,kmf+1]);
A_history1 = A_blank;
B_history1 = B_blank;
A_history2 = A_blank;
B_history2 = B_blank;
A_history3 = A_blank;
B_history3 = B_blank;

% Data-dependent products that stay constant across all iterations.
v1tmp = visfac*D1;
v2tmp = visfac*D2;
v3tmp = visfac*D3;

vishid_ad1 = vishid_a*D1;
vishid_bd1 = vishid_b*D1;
vishid_ad2 = vishid_a*D2;
vishid_bd2 = vishid_b*D2;
vishid_ad3 = vishid_a*D3;
vishid_bd3 = vishid_b*D3;

% Initial estimates use only the direct visible-to-hidden connections.
A1 = sigmoid(bsxfun(@plus,vishid_ad1,hidbias_a));
A_history1(:,:,1) = A1;
B1 = sigmoid(bsxfun(@plus,vishid_bd1,hidbias_b));
B_history1(:,:,1) = B1;
A2 = sigmoid(bsxfun(@plus,vishid_ad2,hidbias_a));
A_history2(:,:,1) = A2;
B2 = sigmoid(bsxfun(@plus,vishid_bd2,hidbias_b));
B_history2(:,:,1) = B2;
A3 = sigmoid(bsxfun(@plus,vishid_ad3,hidbias_a));
A_history3(:,:,1) = A3;
B3 = sigmoid(bsxfun(@plus,vishid_bd3,hidbias_b));
B_history3(:,:,1) = B3;

% order==0: A is updated before B in each iteration; order==1: B before A.
% isequal (rather than ==) keeps an empty gradcheck from breaking '&&'.
if nargin >= 6 && isequal(gradcheck,1)
    order = 0;
else
    order = rand>0.5;
end
if order==0
    for k = 1:kmf
        A1 = sigmoid(bsxfun(@plus, vishid_ad1 + hidfac_a'*(v1tmp.*(hidfac_b*B1)), hidbias_a));
        B1 = sigmoid(bsxfun(@plus, vishid_bd1 + hidfac_b'*(v1tmp.*(hidfac_a*A1)), hidbias_b));
        A_history1(:,:,k+1) = A1;
        B_history1(:,:,k+1) = B1;

        A2 = sigmoid(bsxfun(@plus, vishid_ad2 + hidfac_a'*(v2tmp.*(hidfac_b*B2)), hidbias_a));
        B2 = sigmoid(bsxfun(@plus, vishid_bd2 + hidfac_b'*(v2tmp.*(hidfac_a*A2)), hidbias_b));
        A_history2(:,:,k+1) = A2;
        B_history2(:,:,k+1) = B2;

        A3 = sigmoid(bsxfun(@plus, vishid_ad3 + hidfac_a'*(v3tmp.*(hidfac_b*B3)), hidbias_a));
        B3 = sigmoid(bsxfun(@plus, vishid_bd3 + hidfac_b'*(v3tmp.*(hidfac_a*A3)), hidbias_b));
        A_history3(:,:,k+1) = A3;
        B_history3(:,:,k+1) = B3;
    end
else
    for k = 1:kmf
        B1 = sigmoid(bsxfun(@plus, vishid_bd1 + hidfac_b'*(v1tmp.*(hidfac_a*A1)), hidbias_b));
        A1 = sigmoid(bsxfun(@plus, vishid_ad1 + hidfac_a'*(v1tmp.*(hidfac_b*B1)), hidbias_a));
        A_history1(:,:,k+1) = A1;
        B_history1(:,:,k+1) = B1;

        B2 = sigmoid(bsxfun(@plus, vishid_bd2 + hidfac_b'*(v2tmp.*(hidfac_a*A2)), hidbias_b));
        A2 = sigmoid(bsxfun(@plus, vishid_ad2 + hidfac_a'*(v2tmp.*(hidfac_b*B2)), hidbias_a));
        A_history2(:,:,k+1) = A2;
        B_history2(:,:,k+1) = B2;

        B3 = sigmoid(bsxfun(@plus, vishid_bd3 + hidfac_b'*(v3tmp.*(hidfac_a*A3)), hidbias_b));
        A3 = sigmoid(bsxfun(@plus, vishid_ad3 + hidfac_a'*(v3tmp.*(hidfac_b*B3)), hidbias_a));
        A_history3(:,:,k+1) = A3;
        B_history3(:,:,k+1) = B3;
    end
end

dvisfac = 0*visfac;
dhidfac_a = 0*hidfac_a;
dhidfac_b = 0*hidfac_b;
dvishid_a = 0*vishid_a;
dvishid_b = 0*vishid_b;
dhidbias_a = 0*hidbias_a;
dhidbias_b = 0*hidbias_b;
cost = 0;

% factor of variation 1: A1 and A2 should be similar, A1 and A3 different.
err_same = A1-A2;
err_diff = A1-A3;
sq_dist = sum(err_diff.^2,1);
diff_dist = sqrt(sq_dist);
over_thresh = thresh_a - diff_dist;
%fprintf(1,'%d points too close...\n', sum(over_thresh>0));
a_same = kman_a*0.5*sum(sum(err_same.^2));
a_diff = kman_a*0.5*sum(max(0,over_thresh).^2);
cost = cost + a_same + a_diff;
% Gradients are taken w.r.t. the pre-sigmoid inputs, hence the A.*(1-A).
dA1 = err_same.*(A1.*(1-A1));
dA2 = -err_same.*(A2.*(1-A2));
% 1/distance, with exactly coincident pairs mapped to 0: there the
% direction err_diff/dist is undefined and the unguarded expression
% evaluates Inf*0 = NaN, which would poison every gradient.
inv_dist = sq_dist.^(-.5);
inv_dist(sq_dist==0) = 0;
ddiff = -(over_thresh>0).*over_thresh.*inv_dist;
ddiff = bsxfun(@times, ddiff, err_diff);
dA3 = -ddiff.*(A3.*(1-A3));
dA1 = dA1 + ddiff.*(A1.*(1-A1));
% Scale according to penalty term.
dA1 = kman_a*dA1;
dA2 = kman_a*dA2;
dA3 = kman_a*dA3;

% factor of variation 2: B1 and B3 should be similar, B1 and B2 different.
err_same = B1-B3;
err_diff = B1-B2;
sq_dist = sum(err_diff.^2,1);
diff_dist = sqrt(sq_dist);
over_thresh = thresh_b - diff_dist;
b_same = kman_b*0.5*sum(sum(err_same.^2));
b_diff = kman_b*0.5*sum(max(0,over_thresh).^2);
cost = cost + b_same + b_diff;
dB1 = err_same.*(B1.*(1-B1));
dB3 = -err_same.*(B3.*(1-B3));
%fprintf(1,'%d points too close...\n', sum(over_thresh>0));
% Same zero-distance guard as for the A codes above.
inv_dist = sq_dist.^(-.5);
inv_dist(sq_dist==0) = 0;
ddiff = -(over_thresh>0).*over_thresh.*inv_dist;
ddiff = bsxfun(@times, ddiff, err_diff);
dB2 = -ddiff.*(B2.*(1-B2));
dB1 = dB1 + ddiff.*(B1.*(1-B1));
% Scale.
dB1 = kman_b*dB1;
dB2 = kman_b*dB2;
dB3 = kman_b*dB3;

% Bundle the three data sets so the backward passes can loop over them.
hA = { A_history1, A_history2, A_history3 };
hB = { B_history1, B_history2, B_history3 };
D = { D1, D2, D3 };
gA = { dA1, dA2, dA3 };
gA_sav = gA;
gB = { dB1, dB2, dB3 };
gB_sav = gB;
V = { v1tmp, v2tmp, v3tmp };

% Earliest history slice visited by the backward loops. It equals 2 (full
% unrolling) only when kmf <= 6; otherwise the backward pass is truncated.
kstop = max(2,kmf-4);
% RNN backprop, starting from A.
% The B state that fed A at slice k is slice k-1 when A was updated first
% (order==0) and slice k when B was updated first (order==1), hence the
% k-1+order index.
for n = 1:3
    %for k = (kmf+1):-1:2,
    for k = (kmf+1):-1:kstop
        % A backprop.
        dhidbias_a = dhidbias_a + sum(gA{n},2);
        btmp = hidfac_b*hB{n}(:,:,k-1+order);
        dhidfac_a = dhidfac_a + (btmp.*V{n})*gA{n}';
        dhidfac_b = dhidfac_b + ((hidfac_a*gA{n}).*V{n})*hB{n}(:,:,k-1+order)';
        dvisfac = dvisfac + (btmp.*(hidfac_a*gA{n}))*D{n}';
        dvishid_a = dvishid_a + gA{n}*D{n}';

        % B backprop.
        gB{n} = hidfac_b'*(V{n}.*(hidfac_a*gA{n}));
        gB{n} = gB{n}.*hB{n}(:,:,k-1+order).*(1-hB{n}(:,:,k-1+order));
        dhidbias_b = dhidbias_b + sum(gB{n},2);
        dvishid_b = dvishid_b + gB{n}*D{n}';
        % With A-first ordering, the B state at slice 1 is the initial
        % estimate and has no factor inputs, so the chain ends here.
        if (k==2)&&(order==0), break; end
        atmp = hidfac_a*hA{n}(:,:,k-1);
        dhidfac_b = dhidfac_b + (atmp.*V{n})*gB{n}';
        dhidfac_a = dhidfac_a + ((hidfac_b*gB{n}).*V{n})*hA{n}(:,:,k-1)';
        dvisfac = dvisfac + (atmp.*(hidfac_b*gB{n}))*D{n}';

        gA{n} = hidfac_a'*(V{n}.*(hidfac_b*gB{n}));
        gA{n} = gA{n}.*hA{n}(:,:,k-1).*(1-hA{n}(:,:,k-1));
    end
    % Remaining gradient on A is credited to the direct (bias and
    % visible-to-hidden) parameters only.
    if (order==1) || (kmf==0)
        dhidbias_a = dhidbias_a + sum(gA{n},2);
        dvishid_a = dvishid_a + gA{n}*D{n}';
    end
end

% Start from the top.
gA = gA_sav;
gB = gB_sav;

% RNN backprop, starting from B.
% Mirror image of the pass above: the A state that fed B at slice k is
% slice k for order==0 and slice k-1 for order==1, hence k-order.
for n = 1:3
    %for k = (kmf+1):-1:2,
    for k = (kmf+1):-1:kstop
        % B backprop.
        dhidbias_b = dhidbias_b + sum(gB{n},2);
        atmp = hidfac_a*hA{n}(:,:,k-order);
        dhidfac_b = dhidfac_b + (atmp.*V{n})*gB{n}';
        dhidfac_a = dhidfac_a + ((hidfac_b*gB{n}).*V{n})*hA{n}(:,:,k-order)';
        dvisfac = dvisfac + (atmp.*(hidfac_b*gB{n}))*D{n}';
        dvishid_b = dvishid_b + gB{n}*D{n}';

        % A backprop.
        gA{n} = hidfac_a'*(V{n}.*(hidfac_b*gB{n}));
        gA{n} = gA{n}.*hA{n}(:,:,k-order).*(1-hA{n}(:,:,k-order));
        dhidbias_a = dhidbias_a + sum(gA{n},2);
        dvishid_a = dvishid_a + gA{n}*D{n}';
        % With B-first ordering, the A state at slice 1 is the initial
        % estimate and has no factor inputs, so the chain ends here.
        if (k==2)&&(order==1), break; end
        btmp = hidfac_b*hB{n}(:,:,k-1);
        dhidfac_a = dhidfac_a + (btmp.*V{n})*gA{n}';
        dhidfac_b = dhidfac_b + ((hidfac_a*gA{n}).*V{n})*hB{n}(:,:,k-1)';
        dvisfac = dvisfac + (btmp.*(hidfac_a*gA{n}))*D{n}';

        gB{n} = hidfac_b'*(V{n}.*(hidfac_a*gA{n}));
        gB{n} = gB{n}.*hB{n}(:,:,k-1).*(1-hB{n}(:,:,k-1));
    end
    % Remaining gradient on B is credited to the direct (bias and
    % visible-to-hidden) parameters only.
    if (order==0) || (kmf==0)
        dhidbias_b = dhidbias_b + sum(gB{n},2);
        dvishid_b = dvishid_b + gB{n}*D{n}';
    end
end

grad = struct();
grad.visfac = dvisfac;
grad.hidfac_a = dhidfac_a;
grad.hidfac_b = dhidfac_b;
grad.vishid_a = dvishid_a;
grad.vishid_b = dvishid_b;
grad.hidbias_a = dhidbias_a;
grad.hidbias_b = dhidbias_b;
