
function train_sprites_manifold(use_gpu, transfer, masked, alpha, beta, sample_easy, expname, init_from)
% TRAIN_SPRITES_MANIFOLD  Train the sprite analogy encoder (net 0) and decoder (net 1) with matcaffe.
%
%   train_sprites_manifold(use_gpu, transfer, masked, alpha, beta, sample_easy, expname, init_from)
%
%   All arguments are optional:
%     use_gpu     - forwarded to matcaffe_init_sprites_train (default -1).
%     transfer    - 1: batches come from sample_analogy_transfer,
%                   otherwise from sample_analogy (default 1).
%     masked      - stored in pars.masked and in the output file name (default 0).
%     alpha       - stored in pars.alpha and in the output file name (default 10.0).
%     beta        - weight of the squared error between the decoder's hidden
%                   output and the encoder's 'target' code (default 20.0).
%     sample_easy - forwarded via pars to the batch samplers (default 0).
%     expname     - 'deep' or 'vec': selects how the decoder pose input and
%                   switch input are built (default 'deep').
%     init_from   - optional; forwarded to matcaffe_init_sprites_train when given.
%
%   Every 10 steps a preview figure is printed to images/sprites_manifold/<fname>.png;
%   every 1000 steps the encoder/decoder weights and pars are saved to
%   results/sprites_manifold/<fname>.mat.

if ~exist('use_gpu','var'),
    use_gpu = -1;
end

if ~exist('masked', 'var'),
    masked = 0;
end

if ~exist('alpha', 'var'),
    alpha = 10.0;
end

if ~exist('beta', 'var'),
    beta = 20.0;
end

if ~exist('sample_easy', 'var'),
    sample_easy = 0;
end

if ~exist('expname', 'var'),
    expname = 'deep';
end

if ~exist('transfer', 'var'),
    transfer = 1;
end

% Fail fast: the training loop only defines the decoder inputs top_pose/sw
% for these two modes; anything else would crash later with an obscure
% "undefined variable" error.
if ~any(strcmp(expname, {'deep', 'vec'})),
    error('train_sprites_manifold:badExpname', ...
          'expname must be ''deep'' or ''vec'', got ''%s''.', expname);
end

solver_enc = 'analogy_sprite_enc_solver.prototxt';
solver_dec = 'analogy_sprite_dec_solver.prototxt';

addpath('../../matlab/caffe');
addpath('../model');
addpath('../util');

subdir = 'sprites_manifold';
solver_enc = sprintf('results/%s/%s', subdir, solver_enc);
solver_dec = sprintf('results/%s/%s', subdir, solver_dec);

if exist('init_from', 'var'),
    matcaffe_init_sprites_train(use_gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_sprites_train(use_gpu, solver_enc, solver_dec);
end

%% Load data.
datadir = '../data/sprites';
splits = load([datadir '/sprites_splits.mat']);

files = dir([datadir '/*.mat']);
tmp = files(splits.trainidx);
trainfiles = arrayfun(@(x) [ datadir '/' x.name ], tmp, 'UniformOutput', false);

%% Init model.
pars = struct('numstep', 200000, ...
              'alpha', alpha, ...
              'beta', beta, ...
              'sample_easy', sample_easy, ...
              'numhid', 512, ...
              'vdim', 60, ...
              'masked', masked, ...
              'transfer', transfer, ...
              'expname', expname, ...
              'lambda_mask', 1, ...
              'lambda_rgb', 10, ...
              'numfactor', 8, ...
              'batchsize', 25);
pars.fname = sprintf('analogy_sprite_manifold_a%.4f_b%.4f_m%d_e%d_%s', alpha, beta, masked, pars.sample_easy, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);
% Cardinality of each categorical label entry. The 7th entry has 3 states
% because fix_wpn remaps it to 0/1/2.
pars.card = [ 2, 4, 3, 6, 2, 2, 3 ];
pars.num_categorical = sum(pars.card);
% The first num_categorical hidden units are the ID units; the rest are pose units.
pars.id_idx = 1:pars.num_categorical;
pars.pose_idx = (pars.num_categorical+1):pars.numhid;
disp(pars);

% print() does not create directories; make sure the preview target exists.
pngdir = fileparts(pars.fname_png);
if ~isempty(pngdir) && ~exist(pngdir, 'dir'),
    mkdir(pngdir);
end

%% Train.
caffe('set_phase_train');
for b = 1:pars.numstep,
    if pars.transfer == 1,
        [ batch_data, batch_masks, batch_labels ] = sample_analogy_transfer(trainfiles, pars);
    else
        [ batch_data, batch_masks, batch_labels ] = sample_analogy(trainfiles, pars);
    end

    Xin = { single(batch_data.X1); single(batch_data.X2); ...
            single(batch_data.X3); single(batch_data.X4)};

    % Encoder (net 0): four outputs used below as out, query, ref and target.
    result = caffe('forward', Xin, 0);
    out = squeeze(result{1});
    query = squeeze(result{2});
    ref = squeeze(result{3});
    target = squeeze(result{4});

    % Transformation = difference of codes, restricted to the pose units.
    trans = out - ref;
    trans = trans(pars.pose_idx,:);
    switch expname
        case 'deep'
            top_pose = query(pars.pose_idx,:);
            sw = ones(size(trans), 'single');
        case 'vec'
            top_pose = query(pars.pose_idx,:) + trans;
            sw = zeros(size(trans), 'single');
    end

    % Regularizers.
    [ cost_hid, deltas_hid, acc ] = ...
        grad_hid_disc_sprites(ref, out, query, target, batch_labels, pars);

    % Replace ID units with attributes.
    top_id = flatten(batch_labels.L3, pars);
    diff_id = flatten(batch_labels.L2, pars);

    % Decoder (net 1).
    result = caffe('forward', {trans, top_id, top_pose, sw, diff_id }, 1);
    hid_new = squeeze(result{1});
    pred_mask = result{2};
    pred_sprite = result{3};

    % Hidden unit regularizer: pull the decoder's hidden output towards the
    % encoder's 'target' code.
    err_hid = hid_new - target;
    cost_hid = cost_hid + pars.beta*0.5*(err_hid(:)'*err_hid(:))/pars.batchsize;
    delta_hid = pars.beta*err_hid/pars.batchsize;

    % Reconstruction cost.
    [ cost_recon, delta_sprite, delta_mask ] = ...
        grad_recon(pred_sprite, pred_mask, batch_data.target, batch_masks.target, pars);

    % Update decoder.
    delta_recon = caffe('backward', { delta_hid; delta_mask; delta_sprite }, 1);
    caffe('update', 1);

    % Gradients w.r.t. the decoder inputs. The ID input (delta_recon{2}) is
    % fed from labels, not the encoder, so it is not propagated back.
    delta_trans = squeeze(delta_recon{1});
    delta_pose = squeeze(delta_recon{3});

    % Combine gradients from the regularizers and reconstruction.
    % trans = out - ref on the pose units, so d/d(ref) = -delta_trans, d/d(out) = +delta_trans.
    delta_ref = deltas_hid{1};
    delta_ref(pars.pose_idx,:) = delta_ref(pars.pose_idx,:) - delta_trans;

    delta_out = deltas_hid{2};
    delta_out(pars.pose_idx,:) = delta_out(pars.pose_idx,:) + delta_trans;

    delta_query = deltas_hid{3};
    delta_query(pars.pose_idx,:) = delta_query(pars.pose_idx,:) + delta_pose;

    % err_hid = hid_new - target, hence the minus sign for the target code.
    delta_target = deltas_hid{4} - delta_hid;

    % Update encoder.
    caffe('backward', { delta_out; delta_query; delta_ref; delta_target }, 0);
    caffe('update', 0);

    % Demo analogy: show the first example of the batch and the prediction.
    if mod(b,10)==0,
        fprintf(1, 'step %d of %d, cost_hid=%g, acc=%.4g costrecon=%g\n', ...
                b, pars.numstep, cost_hid, acc, cost_recon);
        pred = bsxfun(@times, pred_sprite, batch_masks.target);
        subplot(2,4,1);
        imagesc(batch_data.X1(:,:,:,1));
        subplot(2,4,2);
        imagesc(batch_data.X2(:,:,:,1));
        subplot(2,4,3);
        imagesc(batch_data.X3(:,:,:,1));
        subplot(2,4,4);
        imagesc(batch_data.X4(:,:,:,1));
        subplot(2,4,5);
        imagesc(mytrim(pred_sprite(:,:,:,1), 0, 1));
        subplot(2,4,6);
        imagesc(mytrim(pred_mask(:,:,:,1), 0, 1)); colormap gray;
        subplot(2,4,7);
        imagesc(mytrim(pred(:,:,:,1), 0, 1));
        drawnow;
        print('-dpng', pars.fname_png);
    end

    % Periodic checkpoint of both networks plus the run parameters.
    if mod(b,1000)==0,
        model = struct('enc', caffe('get_weights', 0), ...
                       'dec', caffe('get_weights', 1), ...
                       'pars', pars);
        save(pars.fname_mat, '-struct', 'model');
    end
end

end

function [ batch_sprites, batch_masks, batch_labels ] = ...
    sample_analogy_transfer(files, pars)
  % SAMPLE_ANALOGY_TRANSFER  Sample an analogy batch X1:X2 :: X3:X4 across two characters.
  %
  %   FILES is a cell array of .mat paths. Each file is loaded and its fields
  %   sprites, masks (cell arrays indexed by animation) and labels are used.
  %   Two distinct characters are drawn and both use the same randomly
  %   chosen animation. X1/X3 use frame indices t1_idx; X2/X4 use t2_idx, a
  %   random permutation of t1_idx. Character 1 gives X1/X2 and character 2
  %   gives X3/X4.
  %
  %   Returns:
  %     batch_sprites - fields X1..X4 and target (= X4), each reshaped to
  %                     vdim x vdim x 3 x batchsize.
  %     batch_masks   - fields X1..X4 and target (= X4), each reshaped to
  %                     vdim x vdim x 1 x batchsize.
  %     batch_labels  - L1 = L2 = character 1 labels, L3 = L4 = character 2
  %                     labels, each passed through fix_wpn for the animation.

  batch_sprites = struct;
  batch_masks = struct;
  batch_labels = struct;

  % Randomly pick two distinct characters. numel (rather than size(files,1))
  % also works when FILES is a row cell array.
  idx = randsample(numel(files), 2);
  data1 = load(files{idx(1)});
  data2 = load(files{idx(2)});

  % Randomly pick one of the animations; the last one is never selected.
  idx_anim = randsample(length(data1.sprites)-1, 1);

  labels1 = fix_wpn(data1.labels, idx_anim);
  labels2 = fix_wpn(data2.labels, idx_anim);
  batch_labels.L1 = labels1;
  batch_labels.L2 = labels1;
  batch_labels.L3 = labels2;
  batch_labels.L4 = labels2;

  sprite1 = data1.sprites{idx_anim};
  mask1 = data1.masks{idx_anim};
  sprite2 = data2.sprites{idx_anim};
  mask2 = data2.masks{idx_anim};

  % Frames are columns; sample with replacement, then shuffle for the pair.
  t1_idx = randsample(1:size(mask1,2), pars.batchsize, 1);
  t2_idx = t1_idx(randperm(numel(t1_idx)));

  sprite_size = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_size = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites.X1 = reshape(sprite1(:,t1_idx), sprite_size);
  batch_sprites.X2 = reshape(sprite1(:,t2_idx), sprite_size);
  batch_sprites.X3 = reshape(sprite2(:,t1_idx), sprite_size);
  batch_sprites.X4 = reshape(sprite2(:,t2_idx), sprite_size);
  batch_sprites.target = batch_sprites.X4;

  batch_masks.X1 = reshape(mask1(:,t1_idx), mask_size);
  batch_masks.X2 = reshape(mask1(:,t2_idx), mask_size);
  batch_masks.X3 = reshape(mask2(:,t1_idx), mask_size);
  batch_masks.X4 = reshape(mask2(:,t2_idx), mask_size);
  batch_masks.target = batch_masks.X4;
end

function [ batch_sprites, batch_masks, batch_labels ] = ...
    sample_analogy(files, pars)
  % SAMPLE_ANALOGY  Sample an analogy batch X1:X2 :: X3:X4 plus a transfer image X5.
  %
  %   FILES is a cell array of .mat paths. Each file provides fields sprites,
  %   masks (cell arrays indexed by animation) and labels. Two distinct
  %   characters are drawn. X1/X2 come from character 1, X3/X4 from
  %   character 2, and X5 is character 1 at the animation and frame of X3.
  %   The same per-example offset maps X1 -> X2 and X3 -> X4:
  %     * with probability 0.7: a frame offset in {-1,0,+1} inside one animation;
  %     * with probability 0.3: an animation offset in {-1,0,+1} (wrapping
  %       modulo 4) inside a group of 4 consecutive animations.
  %
  %   Returns:
  %     batch_sprites - X1..X5, each vdim x vdim x 3 x batchsize, and
  %                     target = cat(4, X4, X5) (2*batchsize examples).
  %     batch_masks   - same layout with a single channel.
  %     batch_labels  - L1/L2 from character 1, L3/L4 from character 2,
  %                     each passed through fix_wpn for its animation.

  batch_sprites = struct;
  batch_masks = struct;
  batch_labels = struct;

  % Randomly pick two distinct characters (numel also handles row cell arrays).
  idx = randsample(numel(files), 2);
  idx1 = idx(1);
  idx2 = idx(2);

  data1 = load(files{idx1});
  data2 = load(files{idx2});
  batch_labels.L1 = data1.labels;
  batch_labels.L2 = data1.labels;
  batch_labels.L3 = data2.labels;
  batch_labels.L4 = data2.labels;

  anim1_sprite = data1.sprites;
  anim1_mask = data1.masks;
  anim2_sprite = data2.sprites;
  anim2_mask = data2.masks;

  if rand() > 0.3, % frame offsets within existing animations (probability 0.7)
    % Randomly pick one animation per character (the last one is never chosen).
    idx_anim1 = randsample(length(anim1_sprite)-1,1);
    idx_anim2 = randsample(length(anim1_sprite)-1,1);

    if pars.sample_easy,
      % Only sample from the same animation group of 4.
      group = ceil(idx_anim1/4);
      idx_anim2 = randsample(1:4,1) + (group-1)*4;
    end

    batch_labels.L1 = fix_wpn(batch_labels.L1, idx_anim1);
    batch_labels.L2 = fix_wpn(batch_labels.L2, idx_anim1);
    batch_labels.L3 = fix_wpn(batch_labels.L3, idx_anim2);
    batch_labels.L4 = fix_wpn(batch_labels.L4, idx_anim2);

    sprite1 = anim1_sprite{idx_anim1};
    mask1 = anim1_mask{idx_anim1};
    sprite2 = anim2_sprite{idx_anim2};
    mask2 = anim2_mask{idx_anim2};
    sprite3 = anim1_sprite{idx_anim2}; % ID from out, pose from query.
    mask3 = anim1_mask{idx_anim2};

    num_frames = size(mask1,2);
    t1_idx = randsample(1:num_frames, pars.batchsize, 1);
    offset = randsample(-1:1, pars.batchsize,1);
    % Zero any offset that would step outside the first animation.
    t2_idx = t1_idx + offset;
    bad_idx = (t2_idx > num_frames) | (t2_idx < 1);
    offset(bad_idx) = 0;
    t2_idx = t1_idx + offset;

    num_frames = size(mask2,2);
    t3_idx = randsample(1:num_frames,pars.batchsize,1);
    if pars.sample_easy,
      t3_idx = t1_idx;
    end
    % Re-draw t3 where applying the offset would leave the second animation.
    % randi is used instead of randsample(2:num_frames, ...): randsample
    % treats a scalar population (2:2 when num_frames == 2) as 1:n and could
    % return frame 1 here, producing t4_idx == 0.
    badlower = (t3_idx==1)&(offset==-1);
    badupper = (t3_idx==num_frames)&(offset==1);
    if any(badlower),
      t3_idx(badlower) = randi([2, num_frames], 1, nnz(badlower));
    end
    if any(badupper),
      t3_idx(badupper) = randi([1, num_frames-1], 1, nnz(badupper));
    end
    t4_idx = t3_idx + offset;

    batch_sprites.X1 = sprite1(:,t1_idx);
    batch_sprites.X2 = sprite1(:,t2_idx);
    batch_sprites.X3 = sprite2(:,t3_idx);
    batch_sprites.X4 = sprite2(:,t4_idx);
    batch_sprites.X5 = sprite3(:,t3_idx); % id of X2, pose of X3.

    batch_masks.X1 = mask1(:,t1_idx);
    batch_masks.X2 = mask1(:,t2_idx);
    batch_masks.X3 = mask2(:,t3_idx);
    batch_masks.X4 = mask2(:,t4_idx);
    batch_masks.X5 = mask3(:,t3_idx);
  else % animation offsets within a group of 4 (probability 0.3)
    % Start index (minus one) of the action group for each character.
    idx_action = 4*(randsample(1:5, 2)-1);

    anim1_idx = randsample(1:4, pars.batchsize, 1);
    offset = randsample(-1:1, pars.batchsize,1);
    anim2_idx = anim1_idx + offset;
    anim2_idx(anim2_idx>4) = anim2_idx(anim2_idx>4) - 4;
    anim2_idx(anim2_idx<1) = anim2_idx(anim2_idx<1) + 4;

    anim3_idx = randsample(1:4, pars.batchsize, 1);
    anim4_idx = anim3_idx + offset;
    anim4_idx(anim4_idx>4) = anim4_idx(anim4_idx>4) - 4;
    anim4_idx(anim4_idx<1) = anim4_idx(anim4_idx<1) + 4;

    batch_labels.L1 = fix_wpn(batch_labels.L1, idx_action(1)+1);
    batch_labels.L2 = fix_wpn(batch_labels.L2, idx_action(1)+1);
    batch_labels.L3 = fix_wpn(batch_labels.L3, idx_action(2)+1);
    batch_labels.L4 = fix_wpn(batch_labels.L4, idx_action(2)+1);

    for i=1:pars.batchsize,
      sprite1 = anim1_sprite{idx_action(1)+anim1_idx(i)};
      mask1 = anim1_mask{idx_action(1)+anim1_idx(i)};
      sprite2 = anim1_sprite{idx_action(1)+anim2_idx(i)};
      mask2 = anim1_mask{idx_action(1)+anim2_idx(i)};

      sprite3 = anim2_sprite{idx_action(2)+anim3_idx(i)};
      mask3 = anim2_mask{idx_action(2)+anim3_idx(i)};
      sprite4 = anim2_sprite{idx_action(2)+anim4_idx(i)};
      mask4 = anim2_mask{idx_action(2)+anim4_idx(i)};

      % ID of sprite2, pose of sprite3.
      sprite5 = anim1_sprite{idx_action(2)+anim3_idx(i)};
      mask5 = anim1_mask{idx_action(2)+anim3_idx(i)};

      t1_idx = randsample(1:size(mask1,2),1);
      t2_idx = randsample(1:size(mask3,2),1);
      batch_sprites.X1(:,i) = sprite1(:,t1_idx);
      batch_sprites.X2(:,i) = sprite2(:,t1_idx);
      batch_sprites.X3(:,i) = sprite3(:,t2_idx);
      batch_sprites.X4(:,i) = sprite4(:,t2_idx);
      batch_sprites.X5(:,i) = sprite5(:,t2_idx);

      batch_masks.X1(:,i) = mask1(:,t1_idx);
      batch_masks.X2(:,i) = mask2(:,t1_idx);
      batch_masks.X3(:,i) = mask3(:,t2_idx);
      batch_masks.X4(:,i) = mask4(:,t2_idx);
      batch_masks.X5(:,i) = mask5(:,t2_idx);
    end
  end

  % Columns hold flattened images; reshape to image stacks.
  for f = {'X1', 'X2', 'X3', 'X4', 'X5'},
    name = f{1};
    batch_sprites.(name) = reshape(batch_sprites.(name), ...
                                   [ pars.vdim, pars.vdim, 3, pars.batchsize ]);
    batch_masks.(name) = reshape(batch_masks.(name), ...
                                 [ pars.vdim, pars.vdim, 1, pars.batchsize ]);
  end

  batch_sprites.target = cat(4, batch_sprites.X4, batch_sprites.X5);
  batch_masks.target = cat(4, batch_masks.X4, batch_masks.X5);
end

function L = fix_wpn(L, idx_anim)
  % FIX_WPN  Remap the 7th label entry according to the animation index.
  %
  %   L(7) == 0 (annotated "bow" in the original code):
  %       becomes 1 when 17 <= idx_anim <= 20, otherwise 0.
  %   L(7) == 1 (annotated "spear" in the original code):
  %       becomes 2 when 5 <= idx_anim <= 12, otherwise 0.
  %   Any other value of L(7) is left unchanged.
  %   NOTE(review): presumably 0/1/2 afterwards mean "no weapon shown" /
  %   "bow shown" / "spear shown" in that animation -- confirm with the data.
  switch L(7)
    case 0 % bow
      in_bow_anims = (idx_anim >= 17) && (idx_anim <= 20);
      if in_bow_anims
        L(7) = 1;
      else
        L(7) = 0;
      end
    case 1 % spear
      in_spear_anims = (idx_anim >= 5) && (idx_anim <= 12);
      if in_spear_anims
        L(7) = 2;
      else
        L(7) = 0;
      end
  end
end

function L = flatten(Lin, pars)
    % FLATTEN  Stack one-of-c encodings of a label vector into a batch matrix.
    %
    %   Lin  - label vector with 0-based values, one entry per attribute.
    %   pars - uses pars.card (cardinality per attribute) and pars.batchsize.
    %
    %   Returns a single-precision sum(pars.card) x pars.batchsize matrix.
    %   Rows for attribute k hold oneofc(Lin(k)+1, pars.card(k)), and every
    %   column is identical (the same labels are repeated across the batch).
    card = pars.card;
    offsets = [0, cumsum(card(:)')];
    L = zeros(sum(card), pars.batchsize, 'single');
    for k = 1:length(card),
        rows = offsets(k) + (1:card(k));
        code = oneofc(Lin(k) + 1, card(k));
        L(rows, :) = repmat(code, 1, pars.batchsize);
    end
end
