
function train_cars_analogy(use_gpu, masked, alpha, init_from)
% TRAIN_CARS_ANALOGY  Train the car-view analogy encoder/decoder with matcaffe.
%
%   train_cars_analogy()
%   train_cars_analogy(use_gpu, masked, alpha, init_from)
%
%   use_gpu   - forwarded to matcaffe_init_cars_train (default -1).
%   masked    - stored as pars.masked and embedded in the output file name
%               (default 0).
%   alpha     - stored as pars.alpha, the weight of the hidden-unit cost in
%               update_mul (default 0.01); embedded in the output file name.
%   init_from - optional; forwarded to matcaffe_init_cars_train when given.
%
%   Every 10 steps the cost is printed, and the first example of the batch
%   is plotted and written to pars.fname_png. The plot shows the three
%   input views and the predicted fourth view.
%
%   Every 1000 steps the encoder (net 0) and decoder (net 1) weights are
%   saved together with pars to pars.fname_mat. They are saved once more
%   after the last step when pars.numstep is not a multiple of 1000.

if ~exist('use_gpu', 'var'),
    use_gpu = -1;
end

if ~exist('masked', 'var'),
    masked = 0;
end

if ~exist('alpha', 'var'),
    alpha = 0.01;
end

solver_enc = 'analogy_enc_solver.prototxt';
solver_dec = 'analogy_dec_solver.prototxt';

subdir = 'cars_analogy';
sampler = @(files,pars) sample_analogy(files,pars);
update = @(batch_images,batch_masks,pars) update_mul(batch_images, batch_masks, pars);

addpath('../../matlab/caffe');
addpath('../util');
addpath('../model');

solver_enc = sprintf('results/%s/%s', subdir, solver_enc);
solver_dec = sprintf('results/%s/%s', subdir, solver_dec);

if exist('init_from', 'var'),
    matcaffe_init_cars_train(use_gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_cars_train(use_gpu, solver_enc, solver_dec);
end

%% Load data.
datadir = '../data/cars';
ids = 1:100;    % training cars; 101-149 are for validation, 150-199 for testing.
pitches = [6:-1:1,24:-1:19];
elev = 3;
trainfiles = get_car_files(datadir, ids);

%% Init model.
pars = struct('numstep', 300000, ...
              'numhid', 640, ...
              'vdim', 64, ...
              'alpha', alpha, ...
              'masked', masked, ...
              'lambda_mask', 1, ...
              'lambda_rgb', 10, ...
              'pitches', pitches, ...
              'elev', elev, ...
              'batchsize', 25);

pars.fname = sprintf('analogy_cars_a%.2g_m%d_%s', alpha, masked, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);
% Split of the numhid = 640 encoder units: 512 "identity" units and 128
% "pose" units. update_mul only transfers the pose part between views.
pars.id_idx = 1:512;
pars.pose_idx = 513:640;
disp(pars);

% mkdir warns when the directory already exists, so only create it once.
png_dir = sprintf('images/%s', subdir);
if ~exist(png_dir, 'dir'),
    mkdir(png_dir);
end

%% Train.
caffe('set_phase_train');
for b = 1:pars.numstep,
    [ batch_sprites, batch_masks ] = sampler(trainfiles, pars);
    [ pred_sprite, ~, cost ] = update(batch_sprites, batch_masks, pars);

    % Demo analogy.
    if mod(b,10)==0,
        fprintf(1, 'step %d of %d, cost=%g\n', ...
                b, pars.numstep, cost);
        plot_analogy_example(batch_sprites, pred_sprite);
        print('-dpng', pars.fname_png);
    end

    if mod(b,1000)==0,
        save_model_checkpoint(pars);
    end
end

% Persist the final weights if the last step did not land on a checkpoint.
if mod(pars.numstep,1000) ~= 0,
    save_model_checkpoint(pars);
end

end


function plot_analogy_example(batch_sprites, pred_sprite)
% Plot the first example of a batch in a 1x4 row:
% X1, X2, X3 and the predicted fourth view.
panels = { batch_sprites.X1(:,:,:,1), ...
           batch_sprites.X2(:,:,:,1), ...
           batch_sprites.X3(:,:,:,1), ...
           mytrim(pred_sprite(:,:,:,1), 0, 1) };
for k = 1:numel(panels),
    subplot(1,4,k);
    % Bicubic imresize of float data is not clamped and can overshoot the
    % [0,1] range that true-color image data must lie in.
    imagesc(min(max(panels{k}, 0), 1));
end
drawnow;
end


function save_model_checkpoint(pars)
% Save encoder (caffe net 0) and decoder (caffe net 1) weights plus pars
% to pars.fname_mat.
model = struct('enc', caffe('get_weights', 0), ...
               'dec', caffe('get_weights', 1), ...
               'pars', pars);
save(pars.fname_mat, '-struct', 'model');
end


function [ files ] = get_car_files(datadir, ids)
% GET_CAR_FILES  Build the mesh-file path for each requested car id.
%
%   files = get_car_files(datadir, ids) returns a numel(ids)-by-1 cell array.
%   Its k-th entry is '<datadir>/car_NNN_mesh.mat', where NNN is ids(k)
%   printed with at least three digits (zero-padded).

id_column = ids(:);
files = arrayfun(@(car_id) sprintf('%s/car_%.3d_mesh.mat', datadir, car_id), ...
                 id_column, 'UniformOutput', false);

end

function [ pred_data, pred_mask, cost ] = ...
      update_mul(batch_data, batch_masks, pars)
% UPDATE_MUL  One forward/backward/update step of the analogy model.
%
%   [pred_data, pred_mask, cost] = update_mul(batch_data, batch_masks, pars)
%
%   batch_data  - struct with image batches X1..X4, each indexed as
%                 H x W x C x batchsize (as produced by sample_analogy).
%   batch_masks - struct with mask batches; only X4 is used here.
%   pars        - reads id_idx, pose_idx, numhid, batchsize and alpha
%                 (grad_recon may read further fields).
%
%   Encoder pass (caffe net 0):
%     - The encoder runs on X1..X4.
%     - The pose units of (output 1 - output 3) form "trans".
%     - Output 2 is split into identity and pose units.
%
%   Decoder pass (caffe net 1):
%     - Inputs are trans, the identity and pose parts of output 2, an
%       all-ones weight matrix, and X3 with trans tiled over every pixel
%       as extra channels.
%     - Decoder outputs 1..3 are treated as a hidden code, a predicted
%       mask and a predicted image.
%
%   cost = grad_recon's reconstruction cost against X4, plus
%          pars.alpha/2 * ||hidden code - encoder output 4||^2 / batchsize.
%
%   The decoder weights are updated first, then its input gradients are
%   routed back into the encoder, which is updated as well.

    Xin = { single(batch_data.X1); ...
            single(batch_data.X2); ...
            single(batch_data.X3); ...
            single(batch_data.X4) };

    % Encoder forward pass.
    result = caffe('forward', Xin, 0);
    out = squeeze(result{1});
    query = squeeze(result{2});
    ref = squeeze(result{3});
    target = squeeze(result{4});

    % Transformation, restricted to the pose units.
    trans = out - ref;
    trans = trans(pars.pose_idx,:);

    top = query;
    top_id = top(pars.id_idx,:);
    top_pose = top(pars.pose_idx,:);
    % NOTE(review): all-ones per-unit weights for the decoder's 4th input;
    % their role is defined by the decoder prototxt -- confirm there.
    sw = ones(numel(pars.pose_idx),pars.batchsize,'single');

    % Decoder input image: X3 with trans tiled over every pixel as extra
    % channels. Sizes are taken from the data and from pars.pose_idx rather
    % than hard-coded, so they follow the image size and pose-unit count.
    [ img_h, img_w, img_c, ~ ] = size(batch_data.X3);
    npose = numel(pars.pose_idx);
    pose_ch = img_c + (1:npose);    % channels of inpq that hold the tiled trans
    inpq = cat(3, batch_data.X3, ...
               repmat(permute(trans,[4,3,1,2]), [img_h img_w 1 1]));

    % Decoder forward pass.
    result = caffe('forward', { trans, top_id, top_pose, sw, inpq }, 1);
    hid_new = squeeze(result{1});
    pred_mask = result{2};
    pred_data = result{3};

    % Reconstruction cost.
    [ cost_recon, delta_data, delta_mask ] = ...
        grad_recon(pred_data, pred_mask, batch_data.X4, ...
                   batch_masks.X4, pars);

    % Hidden unit cost.
    err_hid = hid_new - target;
    cost_hid = pars.alpha*0.5*(err_hid(:)'*err_hid(:)) / pars.batchsize;
    delta_hid = pars.alpha*err_hid / pars.batchsize;

    % Decoder backward pass and update.
    delta_recon = caffe('backward', { delta_hid; delta_mask; delta_data }, 1);
    caffe('update', 1);
    delta_trans = squeeze(delta_recon{1});
    delta_top_id = squeeze(delta_recon{2});
    delta_top_pose = squeeze(delta_recon{3});
    delta_inpq = delta_recon{5};
    % Every tiled copy of trans carries the same value, so its gradient is
    % the sum of the tiled channels' gradients over all pixels.
    delta_trans = delta_trans + squeeze(sum(sum(delta_inpq(:,:,pose_ch,:),1),2));

    % Scatter decoder-input gradients back onto the encoder outputs.
    delta_out = zeros(pars.numhid,pars.batchsize,'single');
    delta_query = zeros(pars.numhid,pars.batchsize,'single');
    delta_ref = zeros(pars.numhid,pars.batchsize,'single');

    delta_out(pars.pose_idx,:) =   delta_trans;
    delta_query(pars.pose_idx,:) = delta_top_pose;
    delta_query(pars.id_idx,:) =   delta_top_id;
    delta_ref(pars.pose_idx,:) =  -delta_trans;
    delta_target =                -delta_hid;

    % Update encoder.
    caffe('backward', { delta_out; delta_query; delta_ref; delta_target }, 0);
    caffe('update', 0);

    cost = cost_recon + cost_hid;
end

function [ batch_images, batch_masks ] = sample_analogy(files, pars)
% SAMPLE_ANALOGY  Draw one batch of view-analogy quadruples (X1,X2 ; X3,X4).
%
%   [batch_images, batch_masks] = sample_analogy(files, pars)
%
%   files - cell array of .mat paths. Each file must provide 'im', indexed
%           as im(:,:,:,view), and 'mask', indexed as mask(:,:,view).
%           'im' values are divided by 255 below, so presumably 0..255;
%           confirm against the data.
%   pars  - reads batchsize, pitches (pitch index for each usable slot) and
%           elev.
%
%   Car selection:
%     - Two cars are drawn at random; with probability 0.2 both are the
%       same car.
%     - X1/X2 are two views of car 1; X3/X4 are two views of car 2.
%
%   View selection:
%     - Within each example, X1->X2 and X3->X4 differ by the same
%       pitch-slot offset.
%     - |offset| <= 2, and the offset is forced to 0 with probability 0.05.
%     - All views use elevation pars.elev.
%
%   Output:
%     - Images are divided by 255 and resized to 64x64.
%     - Masks are resized, rounded and reshaped to 64 x 64 x 1 x batchsize.
%
%   Errors:
%     - 'sample_analogy:noFiles' if files is empty.
%     - 'sample_analogy:loadFailed' if no pair of files loads within
%       100 random draws.

    if isempty(files),
        error('sample_analogy:noFiles', 'sample_analogy: file list is empty.');
    end

    npitch = numel(pars.pitches);   % number of usable pitch slots

    % Randomly pick two cars (the same car 20% of the time). Redraw when a
    % file fails to load, but only a bounded number of times.
    max_tries = 100;
    loaded = false;
    for attempt = 1:max_tries,
        idx = randsample(numel(files),2,1);
        car1 = idx(1);
        car2 = idx(2);
        if rand < 0.2,
            car2 = car1;
        end
        try
            data1 = load(files{car1});
            if car2 == car1,
                data2 = data1;      % same file: no need to load it twice
            else
                data2 = load(files{car2});
            end
            loaded = true;
            break;
        catch load_err
            last_err = load_err;
        end
    end
    if ~loaded,
        error('sample_analogy:loadFailed', ...
              'sample_analogy: no pair of car files loaded in %d attempts (last error: %s)', ...
              max_tries, last_err.message);
    end

    % Pitch slots of car 1. X2 is X1 shifted by up to 2 slots without
    % running off either end, and 5% of the time it is not shifted at all.
    ept1 = randsample(npitch,pars.batchsize,1);
    offset = 0*ept1;
    for b = 1:pars.batchsize,
        lb = min(2,ept1(b)-1);
        ub = min(2,npitch-ept1(b));
        cand = [-lb:1:-1, 1:1:ub];
        % Pick by index: randsample(v,1) treats a length-1 v as a
        % population size instead of as the population itself.
        offset(b) = cand(randsample(numel(cand),1));
        if rand() < .05,
            offset(b) = 0;
        end
    end
    ept2 = ept1(:) + offset(:);

    % Pitch slots of car 2 use the same per-example offsets. When X4 would
    % fall off either end, the X3/X4 pair is shifted back into range.
    ept3 = randsample(npitch,pars.batchsize,1);
    ept4 = ept3(:) + offset(:);
    shift = max(0, 1 - ept4) - max(0, ept4 - npitch);
    ept3 = ept3(:) + shift;
    ept4 = ept4 + shift;

    % Map pitch slots to pitch indices, then to linear view indices in a
    % 24 x 4 (pitch x elevation) grid.
    % NOTE(review): assumes every file stores its views in that 24x4 order.
    sz = [ 24, 4 ];
    apt = repmat(pars.elev,1,pars.batchsize);
    idx1 = sub2ind(sz, pars.pitches(ept1), apt);
    idx2 = sub2ind(sz, pars.pitches(ept2), apt);
    idx3 = sub2ind(sz, pars.pitches(ept3), apt);
    idx4 = sub2ind(sz, pars.pitches(ept4), apt);

    imgs = { single(data1.im) / 255.0, single(data2.im) / 255.0 };
    masks = { data1.mask, data2.mask };

    % X1/X2 come from car 1 and X3/X4 from car 2.
    names = { 'X1', 'X2', 'X3', 'X4' };
    car_of = [ 1, 1, 2, 2 ];
    views = { idx1, idx2, idx3, idx4 };
    out_sz = [ 64 64 ];

    batch_images = struct;
    batch_masks = struct;
    for k = 1:numel(names),
        src_im = imgs{car_of(k)};
        src_mask = masks{car_of(k)};
        batch_images.(names{k}) = imresize(src_im(:,:,:,views{k}), out_sz);
        resized_mask = round(imresize(single(src_mask(:,:,views{k})), out_sz));
        batch_masks.(names{k}) = reshape(resized_mask, [out_sz 1 pars.batchsize]);
    end
end

