
function train_cars_dis(use_gpu, masked, alpha, init_from)
% TRAIN_CARS_DIS  Train the disentangling analogy encoder/decoder on cars.
%
%   train_cars_dis()
%   train_cars_dis(use_gpu, masked, alpha)
%   train_cars_dis(use_gpu, masked, alpha, init_from)
%
%   use_gpu   - forwarded to matcaffe_init_cars_train (default -1).
%   masked    - stored as pars.masked and embedded in the output file name
%               (default 0).
%   alpha     - weight of the hidden-unit regularizer, stored as pars.alpha
%               (default 0.01).
%   init_from - optional; forwarded to matcaffe_init_cars_train only when
%               supplied.
%
%   Runs pars.numstep steps of sample_disentangle + update_dis. Every 10
%   steps it prints the cost, plots one analogy and writes the figure to
%   images/cars_dis/<fname>.png; every 1000 steps it saves the encoder and
%   decoder weights plus pars to results/cars_dis/<fname>.mat.

if ~exist('use_gpu', 'var'),
    use_gpu = -1;
end

if ~exist('masked', 'var'),
    masked = 0;
end

if ~exist('alpha', 'var'),
    alpha = 0.01;
end

sampler = @(files,pars) sample_disentangle(files,pars);
update = @(batch_images,batch_masks,pars) update_dis(batch_images, batch_masks, pars);

addpath('../../matlab/caffe');
addpath('../util');
addpath('../model');

subdir = 'cars_dis';
solver_enc = sprintf('results/%s/%s', subdir, 'analogy_enc_solver.prototxt');
solver_dec = sprintf('results/%s/%s', subdir, 'analogy_dec_solver.prototxt');

if exist('init_from', 'var'),
    matcaffe_init_cars_train(use_gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_cars_train(use_gpu, solver_enc, solver_dec);
end

%% Load data.
% Split notes kept from earlier experiments: ids 1-100 train,
% 101-149 validation, 150-199 test. Training here uses 1-149.
datadir = '../data/cars';
ids = 1:149;
trainfiles = get_car_files(datadir, ids);

%% Init model.
pars = struct('numstep', 300000, ...
              'numhid', 1024, ...
              'vdim', 64, ...
              'alpha', alpha, ...
              'masked', masked, ...
              'lambda_mask', 1, ...
              'lambda_rgb', 10, ...
              'batchsize', 25);
pars.fname = sprintf('analogy_cars_dis_a%.2g_m%d_%s', alpha, masked, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);
% Only create output folders when missing (plain mkdir warns on reruns).
ensure_dir_exists(sprintf('images/%s', subdir));
ensure_dir_exists(sprintf('results/%s', subdir));
% First half of the hidden units carries identity (taken from X1), the
% second half pose (taken from X2). Derived from numhid so the split
% stays consistent if numhid changes.
nid = floor(pars.numhid / 2);
pars.id_idx = 1:nid;
pars.pose_idx = nid+1:pars.numhid;
disp(pars);

%% Train.
caffe('set_phase_train');
for b = 1:pars.numstep,
    [ batch_sprites, batch_masks ] = sampler(trainfiles, pars);
    [ pred_sprite, ~, cost ] = update(batch_sprites, batch_masks, pars);

    % Periodically report progress and show the current analogy.
    if mod(b,10)==0,
        fprintf(1, 'step %d of %d, cost=%g\n', ...
                b, pars.numstep, cost);
        plot_analogy_demo(batch_sprites, pred_sprite, pars.fname_png);
    end

    % Periodically checkpoint both networks together with pars.
    if mod(b,1000)==0,
        model = struct('enc', caffe('get_weights', 0), ...
                       'dec', caffe('get_weights', 1), ...
                       'pars', pars);
        save(pars.fname_mat, '-struct', 'model');
    end
end

end


function ensure_dir_exists(d)
% ENSURE_DIR_EXISTS  Create folder d unless it already exists.
if ~exist(d, 'dir'),
    mkdir(d);
end
end


function plot_analogy_demo(batch_sprites, pred_sprite, fname_png)
% PLOT_ANALOGY_DEMO  Plot X1, X2, X3 and the (clipped) prediction for the
% first batch element in a 1x4 grid, then print the figure to fname_png.
panels = { batch_sprites.X1(:,:,:,1), ...
           batch_sprites.X2(:,:,:,1), ...
           batch_sprites.X3(:,:,:,1), ...
           mytrim(pred_sprite(:,:,:,1), 0, 1) };
for k = 1:numel(panels),
    subplot(1, numel(panels), k);
    imagesc(panels{k});
end
drawnow;
print('-dpng', fname_png);
end


function [ files ] = get_car_files(datadir, ids)
% GET_CAR_FILES  Build the list of per-car .mat file paths.
%   files = get_car_files(datadir, ids) returns a numel(ids)-by-1 cell
%   array whose k-th entry is '<datadir>/car_NNN_mesh.mat', where NNN is
%   ids(k) zero-padded to at least three digits.

files = arrayfun(@(car_id) sprintf('%s/car_%.3d_mesh.mat', datadir, car_id), ...
                 ids(:), 'UniformOutput', false);

end

function [ pred_data, pred_mask, cost ] = ...
      update_dis(batch_data, batch_masks, pars)
    % UPDATE_DIS  One forward/backward/update step of the encoder (caffe
    % net 0) and the decoder (caffe net 1) on a disentangling batch.
    %
    %   batch_data  - struct with image fields X1, X2, X3 and the selector
    %                 S (numhid x batchsize), as built by the sampler.
    %   batch_masks - struct whose X3 field is passed to grad_recon.
    %   pars        - reads numhid, batchsize and alpha here; also passed
    %                 on to grad_recon.
    %
    %   Returns the decoder's predicted image (pred_data) and mask
    %   (pred_mask), and cost = reconstruction cost + hidden regularizer.
    %   Side effect: applies caffe('update', ...) to both nets.

    % Encoder inputs; X3 is fed twice to fill the fourth input slot.
    % NOTE(review): what the 4th encoder input is used for is defined by
    % the encoder prototxt — confirm.
    Xin = { single(batch_data.X1); ...
            single(batch_data.X2); ...
            single(batch_data.X3); ...
            single(batch_data.X3) };
    S = batch_data.S;
    % All-zero third input for the decoder.
    S2 = zeros(pars.numhid,pars.batchsize,'single');

    % Encoder.
    % NOTE(review): the output order (result{1}->hid2, result{2}->hid3,
    % result{3}->hid1) is fixed by the encoder network definition; the
    % names suggest codes of X2, X3 and X1 — confirm against the prototxt.
    result = caffe('forward', Xin, 0);
    hid2 = squeeze(result{1});
    hid3 = squeeze(result{2});
    hid1 = squeeze(result{3});

    % Combined code: rows where S==1 come from hid1, rows where S==0
    % come from hid2.
    top = S .* hid1 + (1 - S) .* hid2;
    % All-zero "trans" input for the decoder.
    trans = 0*top;

    % Hidden unit regularizer.
    % cost_hid = alpha/2 * ||top - hid3||^2 / batchsize, and delta_hid is
    % its gradient with respect to top.
    err_hid = top - hid3;
    cost_hid = pars.alpha*0.5*(err_hid(:)'*err_hid(:))/pars.batchsize;
    delta_hid = pars.alpha*err_hid/pars.batchsize;

    % Decoder.
    result = caffe('forward', { trans; top; S2 }, 1);
    hid_new = result{1};
    pred_mask = result{2};
    pred_data = result{3};

    % Reconstruction cost.
    [ cost_recon, delta_data, delta_mask ] = ...
        grad_recon(pred_data, pred_mask, batch_data.X3, ...
                   batch_masks.X3, pars);
    % Backprop through the decoder with zero gradient on its first output,
    % then apply the decoder update. delta_recon{2} is presumably the
    % gradient w.r.t. the second decoder input (top) — confirm.
    delta_recon = caffe('backward', { 0*hid_new; delta_mask; delta_data }, 1);
    caffe('update', 1);
    delta_recon = squeeze(delta_recon{2});
    % Route the gradient on top back to the codes it was built from:
    % hid1 gets it on S==1 rows, hid2 on S==0 rows. hid3 gets only the
    % regularizer gradient, with the sign flipped because
    % err_hid = top - hid3.
    % NOTE(review): an earlier comment here said "X2 and X3 are always the
    % same, so only update units extracted from X1 in order to avoid
    % overfitting". However, the sampler draws X3 from car 1 at X2's view,
    % and hid2 does receive a gradient below. Confirm the intended
    % behavior.
    delta_hid1 = S .* (delta_recon + delta_hid);
    delta_hid2 = (1 - S) .* (delta_recon + delta_hid);
    delta_hid3 = -delta_hid;

    % Update encoder: gradients are listed in the same order as the
    % encoder outputs above; the fourth output gets zero gradient.
    caffe('backward', { delta_hid2; delta_hid3; delta_hid1; 0*delta_hid1 }, 0);
    caffe('update', 0);

    cost = cost_recon + cost_hid;
end


function [ batch_images, batch_masks ] = sample_disentangle(files, pars)
% SAMPLE_DISENTANGLE  Draw one training batch of (X1, X2, X3) image triples.
%
%   Two cars are drawn from FILES with replacement, so they may coincide.
%   For each of the pars.batchsize examples, with view indices
%   idx = sub2ind([24 4], pitch, el):
%     X1 - car 1 at index idx1
%     X2 - car 2 at index idx2
%     X3 - car 1 at index idx2 (identity of X1, view of X2)
%   Images are divided by 255 and resized to 64x64. Masks are resized,
%   rounded and reshaped to 64x64x1xB. The fields .target copy X3.
%   batch_images.S is a numhid x batchsize selector that is 1 on the
%   pars.id_idx rows and 0 on the pars.pose_idx rows.
%
%   If loading a file pair fails, a new pair is drawn. After
%   pars.max_load_tries failed attempts (default 100), an error with id
%   'sample_disentangle:loadFailed' is raised instead of looping forever.
    batch_images = struct;
    batch_masks = struct;

    if isfield(pars, 'max_load_tries'),
        max_tries = pars.max_load_tries;
    else
        max_tries = 100;
    end

    % randomly pick two cars; re-draw if either file fails to load.
    loaded = false;
    last_msg = '';
    attempt = 0;
    while ~loaded && attempt < max_tries,
        attempt = attempt + 1;
        idx = randsample(size(files,1),2,1);
        car1 = idx(1);
        car2 = idx(2);
        try
            data1 = load(files{car1});
            data2 = load(files{car2});
            loaded = true;
        catch err
            last_msg = err.message;
        end
    end
    if ~loaded,
        error('sample_disentangle:loadFailed', ...
              'Failed to load a pair of car files after %d attempts; last error: %s', ...
              max_tries, last_msg);
    end

    pitch1 = randsample(24,pars.batchsize,1);
    el1 = randsample(4,pars.batchsize,1);
    pitch2 = randsample(24,pars.batchsize,1);
    el2 = randsample(4,pars.batchsize,1);
    idx1 = sub2ind([24 4], pitch1, el1);
    idx2 = sub2ind([24 4], pitch2, el2);

    batch_images.X1 = data1.im(:,:,:,idx1);
    batch_masks.X1 = data1.mask(:,:,idx1);
    batch_images.X2 = data2.im(:,:,:,idx2);
    batch_masks.X2 = data2.mask(:,:,idx2);
    % Same car as X1, same view indices as X2.
    batch_images.X3 = data1.im(:,:,:,idx2);
    batch_masks.X3 = data1.mask(:,:,idx2);

    batch_images.S = zeros(pars.numhid,pars.batchsize);
    batch_images.S(pars.id_idx,:) = 1; % id from X1.
    batch_images.S(pars.pose_idx,:) = 0;  % pose from X2.

    sz = [ 64 64 ];
    batch_images.X1 = imresize(single(batch_images.X1)/255., sz);
    batch_images.X2 = imresize(single(batch_images.X2)/255., sz);
    batch_images.X3 = imresize(single(batch_images.X3)/255., sz);
    batch_images.target = batch_images.X3;

    % Masks come in as HxWxB; insert a singleton channel dim -> 64x64x1xB.
    batch_masks.X1 = round(imresize(single(batch_masks.X1), sz));
    batch_masks.X1 = permute(batch_masks.X1,[1 2 4 3]);
    batch_masks.X2 = round(imresize(single(batch_masks.X2), sz));
    batch_masks.X2 = permute(batch_masks.X2,[1 2 4 3]);
    batch_masks.X3 = round(imresize(single(batch_masks.X3), sz));
    batch_masks.X3 = permute(batch_masks.X3,[1 2 4 3]);
    batch_masks.target = batch_masks.X3;
end

