function train_shapes(gpu, expname, numstep, alpha, init_from)
% TRAIN_SHAPES  Train an analogy encoder/decoder pair on the shapes data set.
%
%   train_shapes(gpu, expname, numstep, alpha, init_from)
%
%   Every argument is optional:
%     gpu        forwarded to matcaffe_init_shapes_train (default -1).
%     expname    'add', 'mul' or 'deep' (default 'deep'). Selects the update
%                routine and the shapes_<expname> sub-directory used under
%                images/ and results/.
%     numstep    number of training steps (default 100000). The loop runs
%                numstep+1 iterations so that progress is also reported
%                after the last step.
%     alpha      weight of the hidden-unit cost (default 0.01).
%     init_from  forwarded to matcaffe_init_shapes_train when supplied.
%
%   Side effects: adds the caffe/util/model directories to the path, creates
%   the image directory, caches the generated data in ../data/shapes48.mat,
%   writes shapes48_splits.mat, writes a progress PNG every 100 steps and a
%   model snapshot every 10000 steps and at step numstep.
%
%   Raises train_shapes:unknownExpname when expname is not recognised.

script_dir = fileparts(mfilename('fullpath'));

addpath(fullfile(script_dir ,'../../matlab/caffe'));
addpath(fullfile(script_dir ,'../util'));
addpath(fullfile(script_dir ,'../model'));

% Fill in defaults for any arguments the caller left out.
if ~exist('gpu', 'var'),
    gpu = -1;
end

if ~exist('expname', 'var')
    expname = 'deep';
end

if ~exist('numstep', 'var'),
    numstep = 100000;
end

if ~exist('alpha', 'var'),
    alpha = 0.01;
end

% Every experiment samples batches the same way; only the update differs.
sampler = @(data, pars) sample_analogy(data, pars);
switch(expname)
  case 'add'
    % h = f(b) - f(a) + f(c)
    update = @(data, pars) update_shapes_add(data, pars);
  case 'mul'
    % h = f(c) + W x1 [f(b) - f(a)] x2 f(c)
    update = @(data, pars) update_shapes_mul(data, pars);
  case 'deep'
    % h = f(c) + W x1 [f(b) - f(a)] x2 f(c)
    update = @(data, pars) update_shapes_deep(data, pars);
  otherwise
    % Fail here rather than later with an undefined 'update' handle.
    error('train_shapes:unknownExpname', ...
          'Unknown expname ''%s''; expected ''add'', ''mul'' or ''deep''.', ...
          expname);
end

subdir = sprintf('shapes_%s', expname);
imgdir = sprintf('images/%s', subdir);
if ~exist(imgdir, 'dir'),
    mkdir(imgdir);
end
solver_enc = 'analogy_enc_solver.prototxt';
solver_enc = sprintf('results/%s/%s', subdir, solver_enc);
solver_dec = 'analogy_dec_solver.prototxt';
solver_dec = sprintf('results/%s/%s', subdir, solver_dec);

if exist('init_from','var'),
    matcaffe_init_shapes_train(gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_shapes_train(gpu, solver_enc, solver_dec);
end

%% Generate data.
% The cache path is relative to the current directory, not to script_dir.
data_cache_fn = '../data/shapes48.mat';
if ~exist( data_cache_fn, 'file' )
    mkdir_p( fileparts(data_cache_fn) );
    fprintf( 'Generate shape data: ' ); tic
    [M,L] = gen_shapes_data;
    save(data_cache_fn,'M','L');
    toc
end
% Load into a struct so the cached variables do not leak into the workspace.
cached = load(data_cache_fn);
data = cached.M;

% Make training pairs of ids.
% Fixed seed so the same split is produced on every run.
rng('default');
rng(1);
numid = size(data,4)*size(data,5);
% Start from the identity (each id paired with itself) and switch on
% numtrain randomly chosen entries; draws that land on the diagonal do not
% add a new pair. Entries left at zero are presumably the held-out pairs.
pairs = eye(numid);
numtrain = 800;
R = randsample(numel(pairs), numtrain);
pairs(R) = 1;
[tmp1,tmp2] = find(pairs);
trainpairs = [ tmp1, tmp2 ];
save('shapes48_splits.mat', 'pairs', 'trainpairs');
splits = load('shapes48_splits.mat');

data = single(data);

%% Init model.
pars = struct('numstep', numstep, ...
              'numhid', 512, ...
              'vdim', 48, ...
              'alpha', alpha, ...
              'expname', expname, ...
              'batchsize', 25);
pars.fname = sprintf('analogy_shapes_%s_a%.2g_%s', expname, pars.alpha, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);
pars.fname_mat_final = sprintf('results/%s/analogy_shapes_%s.mat', subdir, expname);
pars.solver_enc = solver_enc;
pars.solver_dec = solver_dec;
pars.pairs = splits.pairs;
pars.trainpairs = splits.trainpairs;
disp(pars);

caffe('set_phase_train');
for b = 1:(pars.numstep+1),
    [ batch_data ] = sampler(data, pars);
    [ pred_data, cost_recon ] = update(batch_data, pars);

    % demo analogy.
    if mod(b-1,100)==0,
        fprintf(1, 'step %d of %d, costrecon=%g\n', ...
                b, pars.numstep, cost_recon);
        % First item of the batch: the three inputs and the prediction.
        subplot(1,4,1);
        imagesc(batch_data.X1(:,:,:,1));
        subplot(1,4,2);
        imagesc(batch_data.X2(:,:,:,1));
        subplot(1,4,3);
        imagesc(batch_data.X3(:,:,:,1));
        subplot(1,4,4);
        imagesc(mytrim(pred_data(:,:,:,1), 0, 1));
        drawnow;
        print('-dpng', pars.fname_png);
    end

    % Snapshot every 10000 steps, and also at step numstep so that a run
    % whose length is not a multiple of 10000 still leaves a model on disk.
    if mod(b,10000)==0 || b==pars.numstep,
        model = struct('enc', caffe('get_weights', 0), ...
                       'dec', caffe('get_weights', 1), ...
                       'pars', pars);
        save(pars.fname_mat, '-struct', 'model');
        % copyfile avoids a dependency on a shell 'cp'; a failed copy is
        % reported but does not abort training.
        [copied, copymsg] = copyfile(pars.fname_mat, pars.fname_mat_final);
        if ~copied
            warning('train_shapes:copyFailed', ...
                    'Could not copy "%s" to "%s": %s', ...
                    pars.fname_mat, pars.fname_mat_final, copymsg);
        end
    end
end

end

function [ pred_data, cost_recon ] = update_shapes_add(batch_data, pars)
% UPDATE_SHAPES_ADD  One training step of the additive analogy model.
%
%   [pred_data, cost_recon] = update_shapes_add(batch_data, pars)
%
%   batch_data  struct with image fields X1, X2, X3 and X4.
%   pars        parameter struct; pars.alpha weights the hidden-unit
%               gradient, and pars is also passed to grad_recon_img.
%
%   pred_data   first output of the decoder net for this batch.
%   cost_recon  reconstruction cost reported by grad_recon_img. Note that
%               update_shapes_deep and update_shapes_mul return the total
%               (reconstruction + hidden) cost instead.
%
%   Net 0 is the encoder and net 1 the decoder; both nets are updated.

    Xin = { single(batch_data.X1); ...
            single(batch_data.X2); ...
            single(batch_data.X3); ...
            single(batch_data.X4) };

    % Encoder.
    result = caffe('forward', Xin, 0);
    out = squeeze(result{1});
    query = squeeze(result{2});
    ref = squeeze(result{3});
    target = squeeze(result{4});

    top = out - ref + query;

    % Gradient of the hidden-unit penalty 0.5*alpha*||top - target||^2
    % with respect to top. The penalty value itself is not needed here.
    delta_hid = pars.alpha*(top - target);

    % Decoder.
    result = caffe('forward', { top }, 1);
    pred_data = result{1};

    [ cost_recon, delta_data ] = ...
        grad_recon_img(pred_data, batch_data.X4, pars);
    delta_recon = caffe('backward', { delta_data }, 1);
    caffe('update', 1);
    delta_recon = squeeze(delta_recon{1});

    % top = out - ref + query, so its gradient reaches 'out' and 'query'
    % with a plus sign and 'ref' with a minus sign; 'target' only sees the
    % hidden-unit penalty.
    delta_out =      delta_recon + delta_hid;
    delta_query =    delta_recon + delta_hid;
    delta_ref =     -delta_recon - delta_hid;
    delta_target =  -delta_hid;

    % Update encoder.
    caffe('backward', { delta_out; delta_query; delta_ref; delta_target }, 0);
    caffe('update', 0);
end

function [ pred_data, cost ] = update_shapes_deep(batch_data, pars)
% UPDATE_SHAPES_DEEP  One training step of the 'deep' analogy model.
%
%   [pred_data, cost] = update_shapes_deep(batch_data, pars)
%
%   batch_data  struct with image fields X1, X2, X3 and X4.
%   pars        parameter struct; pars.alpha weights the hidden-unit cost,
%               and pars is also passed to grad_recon_img.
%
%   pred_data   second output of the decoder net.
%   cost        reconstruction cost plus weighted hidden-unit cost.
%
%   Net 0 is the encoder and net 1 the decoder; both nets are updated.

    ENC = 0;
    DEC = 1;

    % Run all four images through the encoder in one call.
    images = cell(4, 1);
    images{1} = single(batch_data.X1);
    images{2} = single(batch_data.X2);
    images{3} = single(batch_data.X3);
    images{4} = single(batch_data.X4);
    codes = caffe('forward', images, ENC);
    code_out = squeeze(codes{1});
    code_query = squeeze(codes{2});
    code_ref = squeeze(codes{3});
    code_target = squeeze(codes{4});

    % The decoder receives the transformation code and the query code.
    code_trans = code_out - code_ref;
    dec_out = caffe('forward', { code_trans, code_query }, DEC);
    code_pred = squeeze(dec_out{1});
    pred_data = dec_out{2};

    % Penalty tying the predicted code to the encoding of X4.
    code_err = code_pred - code_target;
    sq_err = code_err(:)'*code_err(:);
    cost_hid = pars.alpha*0.5*sq_err;
    grad_hid = pars.alpha*code_err;

    % Reconstruction cost, then back-propagate through and update the decoder.
    [ cost_recon, grad_img ] = grad_recon_img(pred_data, batch_data.X4, pars);
    dec_grads = caffe('backward', { grad_hid; grad_img }, DEC);
    caffe('update', DEC);
    grad_trans = squeeze(dec_grads{1});
    grad_query = squeeze(dec_grads{2});

    % Back-propagate through and update the encoder. The transformation
    % code is out - ref, hence the sign flip on the third gradient.
    enc_grads = { grad_trans; grad_query; -grad_trans; -grad_hid };
    caffe('backward', enc_grads, ENC);
    caffe('update', ENC);

    cost = cost_recon + cost_hid;
end

function [ pred_data, cost ] = update_shapes_mul(batch_data, pars)
% UPDATE_SHAPES_MUL  One training step of the 'mul' analogy model.
%
%   [pred_data, cost] = update_shapes_mul(batch_data, pars)
%
%   batch_data  struct with image fields X1, X2, X3 and X4.
%   pars        parameter struct; pars.alpha weights the hidden-unit cost,
%               and pars is also passed to grad_recon_img.
%
%   pred_data   second output of the decoder net.
%   cost        reconstruction cost plus weighted hidden-unit cost.
%
%   Net 0 is the encoder and net 1 the decoder; both nets are updated.

    % --- Encoder forward pass -------------------------------------------
    enc_in = { single(batch_data.X1); single(batch_data.X2); ...
               single(batch_data.X3); single(batch_data.X4) };
    enc_out = caffe('forward', enc_in, 0);
    h_out = squeeze(enc_out{1});
    h_query = squeeze(enc_out{2});
    h_ref = squeeze(enc_out{3});
    h_target = squeeze(enc_out{4});

    % --- Decoder forward pass -------------------------------------------
    % Inputs are the transformation code (out - ref) and the query code.
    h_trans = h_out - h_ref;
    dec_out = caffe('forward', { h_trans, h_query }, 1);
    h_pred = squeeze(dec_out{1});
    pred_data = dec_out{2};

    % --- Costs and their gradients --------------------------------------
    residual = h_pred - h_target;
    residual_vec = residual(:);
    cost_hid = pars.alpha*0.5*(residual_vec'*residual_vec);
    g_hid = pars.alpha*residual;
    [ cost_recon, g_img ] = grad_recon_img(pred_data, batch_data.X4, pars);

    % --- Decoder backward pass and update -------------------------------
    g_dec = caffe('backward', { g_hid; g_img }, 1);
    caffe('update', 1);
    g_trans = squeeze(g_dec{1});
    g_query = squeeze(g_dec{2});

    % --- Encoder backward pass and update -------------------------------
    % 'ref' entered the transformation code with a minus sign; 'target'
    % only appears in the hidden-unit residual, also with a minus sign.
    g_ref = -g_trans;
    g_target = -g_hid;
    caffe('backward', { g_trans; g_query; g_ref; g_target }, 0);
    caffe('update', 0);

    cost = cost_recon + cost_hid;
end

function [ batch_images ] = sample_analogy(data, pars)
% SAMPLE_ANALOGY  Draw one batch of analogy quadruples X1:X2 :: X3:X4.
%
%   batch_images = sample_analogy(data, pars)
%
%   data  array whose dimensions 4..9 are indexed here as colour, shape,
%         scale, angle, x position and y position; the first three
%         dimensions are kept whole (presumably one image each - confirm).
%   pars  struct; uses pars.trainpairs (rows of [id1 id2]) and
%         pars.batchsize.
%
%   Returns a struct with fields X1, X2, X3, X4 and target (a copy of X4).
%   X1/X2 show the identity id1 and X3/X4 the identity id2. One factor
%   (angle, scale, xpos or ypos) is chosen for the whole batch, and the
%   same per-item offset in that factor separates X1 from X2 and X3 from X4.
%
%   The result depends on the global random stream and on the order of the
%   randsample calls below.
    batch_images = struct;

    % randomly pick ids from list of training set pairs.
    idx = randsample(size(pars.trainpairs,1),pars.batchsize);
    id1 = pars.trainpairs(idx,1);
    id2 = pars.trainpairs(idx,2);

    % Randomly pick factor of variation to change (angle, scale, xpos, ypos).
    tochange = randsample(4,1);

    % Sample default angle, scale, xpos ypos (to be potentially
    % overwritten in the following).
    % One for the input pair, and one for the output pair.
    numscale = size(data,6);
    numangle = size(data,7);
    numxpos = size(data,8);
    numypos = size(data,9);
    angle_default1 = randsample(numangle,pars.batchsize,1);
    scale_default1 = randsample(numscale,pars.batchsize,1);
    xpos_default1 = randsample(numxpos,pars.batchsize,1);
    ypos_default1 = randsample(numypos,pars.batchsize,1);
    angle_default2 = randsample(numangle,pars.batchsize,1);
    scale_default2 = randsample(numscale,pars.batchsize,1);
    xpos_default2 = randsample(numxpos,pars.batchsize,1);
    ypos_default2 = randsample(numypos,pars.batchsize,1);

    % assign defaults to analogy particles 1,2,3,4.
    % Images 1 and 2 share one set of defaults, images 3 and 4 the other,
    % so within each pair only the factor chosen below will differ.
    angle1 = angle_default1;
    angle2 = angle_default1;
    angle3 = angle_default2;
    angle4 = angle_default2;
    scale1 = scale_default1;
    scale2 = scale_default1;
    scale3 = scale_default2;
    scale4 = scale_default2;
    xpos1 = xpos_default1;
    xpos2 = xpos_default1;
    xpos3 = xpos_default2;
    xpos4 = xpos_default2;
    ypos1 = ypos_default1;
    ypos2 = ypos_default1;
    ypos3 = ypos_default2;
    ypos4 = ypos_default2;

    if tochange == 1, % angle
      % randomly pick rotation offset (20% chance of being 0).
      % NOTE(review): offset is transposed when used, presumably because
      % randsample returns a row for a row-vector population - confirm.
      offset = randsample(-2:2,pars.batchsize,1);
      angle1 = randsample(numangle,pars.batchsize,1);
      angle2 = angle1 + offset';
      % Angles wrap around: indices outside 1..numangle are shifted by
      % numangle back into range.
      angle2(angle2 < 1) = numangle + angle2(angle2 < 1);
      angle2(angle2 > numangle) = angle2(angle2 > numangle) - numangle;

      angle3 = randsample(numangle,pars.batchsize,1);
      angle4 = angle3 + offset';
      angle4(angle4 < 1) = numangle + angle4(angle4 < 1);
      angle4(angle4 > numangle) = angle4(angle4 > numangle) - numangle;
    elseif tochange == 2, % scale
      % Offset of -1, 0 or +1 per batch item.
      offset = randsample(-1:1,pars.batchsize,1)';
      scale1 = randsample(numscale,pars.batchsize,1);
      scale2 = scale1(:) + offset(:);

      % Handle boundary cases.
      % Scale does not wrap: where scale1 + offset leaves 1..numscale the
      % sign of that item's offset is flipped instead.
      badidx = (scale2<1) | (scale2>numscale);
      offset(badidx) = -1*offset(badidx);
      scale2(badidx) = scale1(badidx) + offset(badidx);

      % The (possibly flipped) offset is reused for the second pair, so
      % scale3 is redrawn wherever scale3 + offset would leave the range.
      % NOTE(review): if numscale == 2 the population 2:numscale is the
      % scalar 2, which randsample would treat as "sample from 1:2" - verify
      % that numscale is always larger than 2.
      scale3 = randsample(numscale,pars.batchsize,1);
      badlower = (scale3==1)&(offset==-1);
      badupper = (scale3==numscale)&(offset==1);
      scale3(badlower) = randsample(2:numscale,sum(badlower),1);
      scale3(badupper) = randsample(1:(numscale-1),sum(badupper),1);
      scale4 = scale3(:) + offset(:);
    elseif tochange == 3, % xpos
      % Same scheme as for scale, applied to the x position.
      offset = randsample(-1:1,pars.batchsize,1)';
      xpos1 = randsample(numxpos,pars.batchsize,1);
      xpos2 = xpos1(:) + offset(:);

      % Handle boundary cases.
      badidx = (xpos2<1) | (xpos2>numxpos);
      offset(badidx) = -1*offset(badidx);
      xpos2(badidx) = xpos1(badidx) + offset(badidx);

      xpos3 = randsample(numxpos,pars.batchsize,1);
      badlower = (xpos3==1)&(offset==-1);
      badupper = (xpos3==numxpos)&(offset==1);
      xpos3(badlower) = randsample(2:numxpos,sum(badlower),1);
      xpos3(badupper) = randsample(1:(numxpos-1),sum(badupper),1);
      xpos4 = xpos3(:) + offset(:);
    else % ypos
      % Same scheme as for scale, applied to the y position.
      offset = randsample(-1:1,pars.batchsize,1)';
      ypos1 = randsample(numypos,pars.batchsize,1);
      ypos2 = ypos1(:) + offset(:);

      % Handle boundary cases.
      badidx = (ypos2<1) | (ypos2>numypos);
      offset(badidx) = -1*offset(badidx);
      ypos2(badidx) = ypos1(badidx) + offset(badidx);

      ypos3 = randsample(numypos,pars.batchsize,1);
      badlower = (ypos3==1)&(offset==-1);
      badupper = (ypos3==numypos)&(offset==1);
      ypos3(badlower) = randsample(2:numypos,sum(badlower),1);
      ypos3(badupper) = randsample(1:(numypos-1),sum(badupper),1);
      ypos4 = ypos3(:) + offset(:);
    end

    % Split each pair id into its (colour, shape) subscripts.
    numcol = size(data,4);
    numshape = size(data,5);
    [ col1, shape1 ] = ind2sub([numcol, numshape], id1);
    [ col2, shape2 ] = ind2sub([numcol, numshape], id2);

    % Turn the six factor subscripts into one linear index over dimensions
    % 4 and up, so each image can be fetched with a single trailing index.
    sz = size(data);
    sz = sz(4:end);
    idx1 = sub2ind(sz, col1, shape1, scale1, angle1, xpos1, ypos1);
    idx2 = sub2ind(sz, col1, shape1, scale2, angle2, xpos2, ypos2);
    idx3 = sub2ind(sz, col2, shape2, scale3, angle3, xpos3, ypos3);
    idx4 = sub2ind(sz, col2, shape2, scale4, angle4, xpos4, ypos4);

    batch_images.X1 = data(:,:,:,idx1);
    batch_images.X2 = data(:,:,:,idx2);
    batch_images.X3 = data(:,:,:,idx3);
    batch_images.X4 = data(:,:,:,idx4);
    batch_images.target = batch_images.X4;
end

