function train_sprites(gpu, expname, masked, alpha, init_from)
% TRAIN_SPRITES  Train the sprite analogy encoder/decoder pair through matcaffe.
%
%   train_sprites(gpu, expname, masked, alpha, init_from)
%
%   gpu       - device id forwarded to matcaffe_init_sprites_train (default 3).
%   expname   - experiment variant: 'add', 'dis', 'add+cls' or 'dis+cls'
%               (default 'dis'). Selects the batch sampler, the update rule
%               and the results/images sub-directory.
%   masked    - stored in pars.masked (default 0).
%   alpha     - stored in pars.alpha and embedded in the output file name
%               (default 10.0).
%   init_from - optional; when given it is forwarded as a fourth argument to
%               matcaffe_init_sprites_train.
%
%   Omitted or empty ([]) arguments take their defaults. An unrecognised
%   expname raises 'train_sprites:unknownExpname' before any network is set up.
%
%   Every 10 steps a 2x3 preview figure is written to pars.fname_png; every
%   1000 steps the encoder/decoder weights and pars are saved to pars.fname_mat.

if ~exist('gpu', 'var') || isempty(gpu)
    gpu = 3;
end

if ~exist('expname', 'var') || isempty(expname)
    expname = 'dis';
end

if ~exist('masked', 'var') || isempty(masked)
    masked = 0;
end

if ~exist('alpha', 'var') || isempty(alpha)
    alpha = 10.0;
end

% Pick sampler, update rule and output sub-directory for the experiment.
% The '+cls' variants share the directory of their base variant.
switch(expname)
  case 'add'
    sampler = @(files, pars) sample_analogy_add(files, pars);
    update = @(batch_data, batch_masks, batch_labels, pars) update_sprites_add(batch_data, batch_masks, batch_labels, pars);
    subdir = 'sprites_add';
  case 'dis'
    sampler = @(files, pars) sample_analogy_dis(files, pars);
    update = @(batch_data, batch_masks, batch_labels, pars) update_sprites_dis(batch_data, batch_masks, batch_labels, pars);
    subdir = 'sprites_dis';
  case 'add+cls'
    sampler = @(files, pars) sample_analogy_add(files, pars);
    update = @(batch_data, batch_masks, batch_labels, pars) update_sprites_add_cls(batch_data, batch_masks, batch_labels, pars);
    subdir = 'sprites_add';
  case 'dis+cls'
    % h = f(c) + W x1 [f(b) - f(a)] x2 f(c)
    sampler = @(files, pars) sample_analogy_dis(files, pars);
    update = @(batch_data, batch_masks, batch_labels, pars) update_sprites_dis_cls(batch_data, batch_masks, batch_labels, pars);
    subdir = 'sprites_dis';
  otherwise
    % Fail early with a clear message instead of an "undefined variable
    % sampler" error deep inside the training loop.
    error('train_sprites:unknownExpname', ...
          'Unknown expname ''%s''; expected ''add'', ''dis'', ''add+cls'' or ''dis+cls''.', ...
          expname);
end

addpath('../util');
addpath('../model');
addpath('../../matlab/caffe');

solver_enc = sprintf('results/%s/%s', subdir, 'analogy_sprite_enc_solver.prototxt');
solver_dec = sprintf('results/%s/%s', subdir, 'analogy_sprite_dec_solver.prototxt');

if exist('init_from', 'var') && ~isempty(init_from)
    matcaffe_init_sprites_train(gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_sprites_train(gpu, solver_enc, solver_dec);
end

%% Load data.
datadir = '../data/sprites';
splits = load(sprintf('%s/sprites_splits.mat', datadir));

% NOTE(review): this listing also matches sprites_splits.mat itself;
% splits.trainidx is presumably defined against this same listing - confirm.
files = dir([datadir '/*.mat']);
tmp = files(splits.trainidx);
trainfiles = arrayfun(@(x) [ datadir '/' x.name ], tmp, 'UniformOutput', false);

%% Init model.
pars = struct('numstep', 200000, ...
              'alpha', alpha, ...
              'numhid', 256, ...
              'vdim', 60, ...
              'masked', masked, ...
              'lambda_mask', 1, ...
              'lambda_rgb', 10, ...
              'numfactor', 8, ...
              'batchsize', 25);
pars.fname = sprintf('analogy_sprite_large_dis_a%.4g_m%d_%s', alpha, masked, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
img_dir = sprintf('images/%s', subdir);
if ~exist(img_dir, 'dir')
    % Only create the directory when missing so reruns do not emit the
    % "directory already exists" warning.
    mkdir(img_dir);
end
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);
% Number of states per categorical attribute. The 7th attribute has three
% states because fix_wpn rewrites it to one of 0/1/2.
pars.card = [ 2, 4, 3, 6, 2, 2, 3 ];
pars.num_categorical = sum(pars.card);
if strcmp(expname, 'dis+cls')
  % Identity units are exactly the flattened attribute code.
  pars.id_idx = 1:pars.num_categorical;
  pars.pose_idx = (pars.num_categorical+1):pars.numhid;
else
  % Split the hidden units evenly between identity and pose.
  pars.id_idx = 1:pars.numhid/2;
  pars.pose_idx = (pars.numhid/2+1):pars.numhid;
end

%% Train.
caffe('set_phase_train');
for b = 1:pars.numstep
    [ batch_data, batch_masks, batch_labels ] = sampler(trainfiles, pars);
    [ pred_sprites, pred_masks, cost_recon ] = update(batch_data, batch_masks, batch_labels, pars);

    % Preview the first example of the batch: inputs on the top row,
    % predicted sprite / mask / masked sprite on the bottom row.
    if mod(b,10)==0
        fprintf(1, 'step %d of %d, costrecon=%g\n', ...
                b, pars.numstep, cost_recon);
        pred = bsxfun(@times, pred_sprites, pred_masks);
        subplot(2,3,1);
        imagesc(batch_data.X1(:,:,:,1));
        subplot(2,3,2);
        imagesc(batch_data.X2(:,:,:,1));
        subplot(2,3,3);
        imagesc(batch_data.X3(:,:,:,1));
        subplot(2,3,4);
        imagesc(mytrim(pred_sprites(:,:,:,1), 0, 1));
        subplot(2,3,5);
        imagesc(mytrim(pred_masks(:,:,:,1), 0, 1)); colormap gray;
        subplot(2,3,6);
        imagesc(mytrim(pred(:,:,:,1), 0, 1));
        drawnow;
        print('-dpng', pars.fname_png);
    end

    % Periodic checkpoint of both networks together with the parameters.
    if mod(b,1000)==0
        model = struct('enc', caffe('get_weights', 0), ...
                       'dec', caffe('get_weights', 1), ...
                       'pars', pars);
        save(pars.fname_mat, '-struct', 'model');
    end
end

end

function [ pred_sprites, pred_masks, cost_recon ] = update_sprites_add(batch_data, batch_masks, batch_labels, pars) 
    % One training step of the additive analogy model.
    %
    % Runs net 0 (encoder) on X1..X4, combines three of the four codes by
    % vector arithmetic, runs net 1 (decoder) on the result, scores the
    % prediction against X4 with grad_recon, and back-propagates through the
    % decoder and then the encoder. batch_labels is accepted for signature
    % parity with the other update rules and is not read here.
    enc_in = cell(4, 1);
    enc_in{1} = single(batch_data.X1);
    enc_in{2} = single(batch_data.X2);
    enc_in{3} = single(batch_data.X3);
    enc_in{4} = single(batch_data.X4);

    % Encoder forward pass; the four outputs are used as out/query/ref/target.
    enc_out = caffe('forward', enc_in, 0);
    out    = squeeze(enc_out{1});
    query  = squeeze(enc_out{2});
    ref    = squeeze(enc_out{3});
    target = squeeze(enc_out{4});

    % Analogy by vector addition in code space.
    top = (out - ref) + query;

    % Decoder forward pass: first output is the mask, second the sprite.
    dec_out = caffe('forward', { top }, 1);
    pred_masks   = dec_out{1};
    pred_sprites = dec_out{2};

    % Reconstruction cost and its gradients w.r.t. the decoder outputs.
    [ cost_recon, grad_sprite, grad_mask ] = ...
        grad_recon(pred_sprites, pred_masks, batch_data.X4, batch_masks.X4, pars);

    % Decoder backward pass and weight update.
    dec_back = caffe('backward', { grad_mask; grad_sprite }, 1);
    grad_top = squeeze(dec_back{1});
    caffe('update', 1);

    % No reconstruction gradient reaches the first num_categorical units.
    grad_top(1:pars.num_categorical, :) = 0;

    % Distribute the reconstruction gradient with the signs of the analogy
    % sum; the target code gets an all-zero gradient of matching shape.
    enc_grads = { grad_top; grad_top; -grad_top; 0 * target };

    % Encoder backward pass and weight update.
    caffe('backward', enc_grads, 0);
    caffe('update', 0);
end

function [ pred_sprites, pred_masks, cost_recon ] = update_sprites_add_cls(batch_data, batch_masks, batch_labels, pars) 
    % One training step of the additive analogy model with an extra
    % label-based term on the hidden codes.
    %
    % Same flow as the plain additive update, except that the gradients
    % returned by grad_sprites_hid_disc are added to the reconstruction
    % gradient before the encoder backward pass.
    enc_in = cell(4, 1);
    enc_in{1} = single(batch_data.X1);
    enc_in{2} = single(batch_data.X2);
    enc_in{3} = single(batch_data.X3);
    enc_in{4} = single(batch_data.X4);

    % Encoder forward pass; the four outputs are used as out/query/ref/target.
    enc_out = caffe('forward', enc_in, 0);
    out    = squeeze(enc_out{1});
    query  = squeeze(enc_out{2});
    ref    = squeeze(enc_out{3});
    target = squeeze(enc_out{4});

    % Analogy by vector addition in code space.
    top = (out - ref) + query;

    % Hidden-code regularizer; only its gradients (a 4-element cell) are
    % consumed below, the cost and accuracy outputs are ignored.
    [ cost_hid, hid_grads, acc ] = ...
        grad_sprites_hid_disc(ref, out, query, target, batch_labels, pars);

    % Decoder forward pass: first output is the mask, second the sprite.
    dec_out = caffe('forward', { top }, 1);
    pred_masks   = dec_out{1};
    pred_sprites = dec_out{2};

    % Reconstruction cost and its gradients w.r.t. the decoder outputs.
    [ cost_recon, grad_sprite, grad_mask ] = ...
        grad_recon(pred_sprites, pred_masks, batch_data.X4, batch_masks.X4, pars);

    % Decoder backward pass and weight update.
    dec_back = caffe('backward', { grad_mask; grad_sprite }, 1);
    grad_top = squeeze(dec_back{1});
    caffe('update', 1);

    % No reconstruction gradient reaches the first num_categorical units.
    grad_top(1:pars.num_categorical, :) = 0;

    % Sum regularizer and reconstruction gradients per encoder output.
    enc_grads = cell(4, 1);
    enc_grads{1} = hid_grads{1} + grad_top;
    enc_grads{2} = hid_grads{2} + grad_top;
    enc_grads{3} = hid_grads{3} - grad_top;
    enc_grads{4} = hid_grads{4};

    % Encoder backward pass and weight update.
    caffe('backward', enc_grads, 0);
    caffe('update', 0);
end

function [ pred_sprites, pred_masks, cost_recon ] = update_sprites_dis(batch_data, batch_masks, batch_labels, pars) 
    % One training step of the disentangling model.
    %
    % The decoder input takes its identity units (pars.id_idx) from hid1 and
    % all remaining units from hid2; the prediction is scored against X3.
    % batch_labels is accepted for signature parity and is not read here.
    enc_in = cell(3, 1);
    enc_in{1} = single(batch_data.X1);
    enc_in{2} = single(batch_data.X2);
    enc_in{3} = single(batch_data.X3);

    % Encoder forward pass. Note the naming: outputs 1, 2, 3 are bound to
    % hid2, hid3, hid1 respectively.
    enc_out = caffe('forward', enc_in, 0);
    hid2 = squeeze(enc_out{1});
    hid3 = squeeze(enc_out{2});
    hid1 = squeeze(enc_out{3});

    % Start from hid2 and swap in the identity units of hid1.
    top = hid2;
    top(pars.id_idx, :) = hid1(pars.id_idx, :);

    % Decoder forward pass: first output is the mask, second the sprite.
    dec_out = caffe('forward', { top }, 1);
    pred_masks   = dec_out{1};
    pred_sprites = dec_out{2};

    % Reconstruction cost and its gradients w.r.t. the decoder outputs.
    [ cost_recon, grad_sprite, grad_mask ] = ...
        grad_recon(pred_sprites, pred_masks, batch_data.X3, batch_masks.X3, pars);

    % Decoder backward pass and weight update.
    dec_back = caffe('backward', { grad_mask; grad_sprite }, 1);
    grad_top = squeeze(dec_back{1});
    caffe('update', 1);

    % Route each slice of the gradient back to the code it was taken from;
    % hid3 did not contribute to the decoder input and gets zeros.
    grad_hid1 = 0 * hid1;
    grad_hid1(pars.id_idx, :) = grad_top(pars.id_idx, :);
    grad_hid2 = 0 * hid2;
    grad_hid2(pars.pose_idx, :) = grad_top(pars.pose_idx, :);
    grad_hid3 = 0 * hid3;

    % Encoder backward pass (same output order as the forward pass) and
    % weight update.
    caffe('backward', { grad_hid2; grad_hid3; grad_hid1 }, 0);
    caffe('update', 0);
end

function [ pred_sprites, pred_masks, cost_recon ] = update_sprites_dis_cls(batch_data, batch_masks, batch_labels, pars) 
    % One training step of the disentangling model with attribute labels.
    %
    % The decoder input is hid2 with its identity units (pars.id_idx)
    % overwritten by the flattened label code of batch_labels.L3; the
    % prediction is scored against X3.
    %
    % Returns the decoder's sprite and mask predictions and the
    % reconstruction cost reported by grad_recon.
    %
    % NOTE(review): cost_hid, deltas_hid and acc are computed below but never
    % used afterwards, so the regularizer does not influence the encoder
    % update in this function (update_sprites_add_cls does add its
    % deltas_hid). Looks like an omission - confirm intended behaviour.
    Xin = { single(batch_data.X1); single(batch_data.X2); ...
            single(batch_data.X3)};
    % Encoder. Outputs 1, 2, 3 are bound to hid2, hid3, hid1 respectively.
    result = caffe('forward', Xin, 0);
    hid2 = squeeze(result{1});
    hid3 = squeeze(result{2});
    hid1 = squeeze(result{3});

    % Regularizers (outputs currently unused, see note at the top).
    [ cost_hid, deltas_hid, acc ] = ...
        grad_hid_disc2(hid1, hid2, batch_labels.L1, batch_labels.L2, pars);

    % Replace ID units with attributes.
    top = hid2;
    L3 = flatten(batch_labels.L3, pars);
    top(pars.id_idx,:) = L3;

    % Decoder. First output is the mask, second the sprite.
    result = caffe('forward', { top }, 1);
    pred_masks = result{1};
    pred_sprites = result{2};

    % Reconstruction cost.
    [ cost_recon, delta_sprite, delta_mask ] = ...
        grad_recon(pred_sprites, pred_masks, batch_data.X3, batch_masks.X3, pars);

    % Update decoder.
    delta_recon = caffe('backward', { delta_mask; delta_sprite }, 1);
    delta_recon = squeeze(delta_recon{1});
    caffe('update', 1);

    % Route the reconstruction gradient back to the encoder outputs.
    % NOTE(review): the ID units of top came from the labels (L3), not from
    % hid1, yet their gradient is still assigned to hid1 here - verify.
    delta_hid1 = 0*hid1;
    delta_hid2 = 0*hid2;
    delta_hid3 = 0*hid3;
    delta_hid1(pars.id_idx,:) = delta_recon(pars.id_idx,:);
    delta_hid2(pars.pose_idx,:) = delta_recon(pars.pose_idx,:);

    % Update encoder (gradients given in the same order as the outputs).
    caffe('backward', { delta_hid2; delta_hid3; delta_hid1 }, 0);
    caffe('update', 0);
end

function [ batch_sprites, batch_masks, batch_labels ] = ...
    sample_analogy_add(files, pars)
  % Draw one analogy batch for the additive model.
  %
  % Two character files are picked at random and one shared animation index
  % is drawn. X1/X2 are frames of the first character, X3/X4 the same frame
  % choices of the second character, so X1:X2 and X3:X4 show the same frame
  % transition. Labels L1/L2 belong to the first character, L3/L4 to the
  % second, each passed through fix_wpn for the chosen animation.
  %
  % Sprites are returned as vdim x vdim x 3 x batchsize arrays, masks as
  % vdim x vdim x 1 x batchsize arrays.

  % Two characters.
  pick = randsample(size(files,1),2);
  char_a = load(files{pick(1)});
  char_b = load(files{pick(2)});

  % One animation, shared by both characters (the last one is never drawn).
  idx_anim = randsample(length(char_a.sprites)-1,1);

  batch_labels = struct;
  batch_labels.L1 = fix_wpn(char_a.labels, idx_anim);
  batch_labels.L2 = fix_wpn(char_a.labels, idx_anim);
  batch_labels.L3 = fix_wpn(char_b.labels, idx_anim);
  batch_labels.L4 = fix_wpn(char_b.labels, idx_anim);

  sprite_a = char_a.sprites{idx_anim};
  mask_a = char_a.masks{idx_anim};
  sprite_b = char_b.sprites{idx_anim};
  mask_b = char_b.masks{idx_anim};

  % Frame columns, drawn with replacement; the second set is a shuffle of
  % the first.
  from_idx = randsample(1:size(mask_a,2), pars.batchsize, 1);
  to_idx = from_idx(randperm(numel(from_idx)));

  % Per output slot X1..X4: source arrays and the frame columns to take.
  sprite_src = { sprite_a, sprite_a, sprite_b, sprite_b };
  mask_src = { mask_a, mask_a, mask_b, mask_b };
  frame_sel = { from_idx, to_idx, from_idx, to_idx };

  sprite_shape = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_shape = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites = struct;
  batch_masks = struct;
  for k = 1:4
      slot = sprintf('X%d', k);
      batch_sprites.(slot) = reshape(sprite_src{k}(:, frame_sel{k}), sprite_shape);
      batch_masks.(slot) = reshape(mask_src{k}(:, frame_sel{k}), mask_shape);
  end
end

function [ batch_sprites, batch_masks, batch_labels ] = ...
  sample_analogy_dis(files, pars)
  % Draw one batch for the disentangling model.
  %
  %   files - cell array of .mat paths, each providing the variables
  %           sprites, masks and labels.
  %   pars  - struct; reads pars.batchsize and pars.vdim.
  %
  % Two characters and two animation indices are drawn:
  %   X1 - character 1, animation idx_anim1
  %   X2 - character 2, animation idx_anim2
  %   X3 - character 1, animation idx_anim2, same frame columns as X2
  % L1..L3 are the matching label vectors, each weapon-corrected by fix_wpn
  % for the animation its image was taken from.
  %
  % Sprites are returned as vdim x vdim x 3 x batchsize arrays, masks as
  % vdim x vdim x 1 x batchsize arrays.

  batch_sprites = struct;
  batch_masks = struct;
  batch_labels = struct;

  % randomly pick two characters (numel works for row and column cells).
  idx = randsample(numel(files),2);
  idx1 = idx(1);
  idx2 = idx(2);

  data1 = load(files{idx1});
  data2 = load(files{idx2});

  anim1_sprite = data1.sprites;
  anim1_mask = data1.masks;
  anim2_sprite = data2.sprites;
  anim2_mask = data2.masks;

  % randomly pick the animations (the last one is never drawn).
  % idx_anim2 indexes both characters, so bound it by the shorter list;
  % when both characters have the same number of animations this is the
  % same draw as before.
  idx_anim1 = randsample(length(anim1_sprite)-1,1);
  idx_anim2 = randsample(min(length(anim1_sprite), length(anim2_sprite))-1,1);

  sprite1 = anim1_sprite{idx_anim1};
  mask1 = anim1_mask{idx_anim1};
  sprite2 = anim2_sprite{idx_anim2};
  mask2 = anim2_mask{idx_anim2};
  sprite3 = anim1_sprite{idx_anim2};
  mask3 = anim1_mask{idx_anim2};

  % Labels follow the image they describe. X3 is character 1 shown in
  % animation idx_anim2, so its weapon attribute must be corrected with
  % idx_anim2 (it was previously corrected with idx_anim1, which mislabels
  % X3 whenever the two animations differ in weapon visibility).
  batch_labels.L1 = fix_wpn(data1.labels, idx_anim1);
  batch_labels.L2 = fix_wpn(data2.labels, idx_anim2);
  batch_labels.L3 = fix_wpn(data1.labels, idx_anim2);

  % Frame columns, drawn with replacement.
  % NOTE(review): t2_idx is drawn from mask2's frame count and also indexes
  % sprite3/mask3 - assumes an animation has the same number of frames for
  % every character; confirm against the data set.
  t1_idx = randsample(1:size(mask1,2), pars.batchsize, 1);
  t2_idx = randsample(1:size(mask2,2), pars.batchsize, 1);

  sprite_shape = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_shape = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites.X1 = reshape(sprite1(:,t1_idx), sprite_shape);
  batch_sprites.X2 = reshape(sprite2(:,t2_idx), sprite_shape);
  batch_sprites.X3 = reshape(sprite3(:,t2_idx), sprite_shape);

  batch_masks.X1 = reshape(mask1(:,t1_idx), mask_shape);
  batch_masks.X2 = reshape(mask2(:,t2_idx), mask_shape);
  batch_masks.X3 = reshape(mask3(:,t2_idx), mask_shape);
end

function L = fix_wpn(L, idx_anim)
  % Rewrite the 7th label entry (weapon) according to the animation index.
  %
  %   L(7) == 0 (bow):   becomes 1 for animations 17..20, otherwise 0.
  %   L(7) == 1 (spear): becomes 2 for animations 5..12, otherwise 0.
  %   Any other value is left untouched, as are all other entries of L.
  switch L(7)
    case 0 % bow
      in_bow_anim = (idx_anim >= 17) && (idx_anim <= 20);
      if in_bow_anim
          L(7) = 1;
      else
          L(7) = 0;
      end
    case 1 % spear
      in_spear_anim = (idx_anim >= 5) && (idx_anim <= 12);
      if in_spear_anim
          L(7) = 2;
      else
          L(7) = 0;
      end
  end
end

function L = flatten(Lin, pars)
    % Expand a vector of zero-based categorical labels into one stacked code
    % that is replicated across the batch.
    %
    % For attribute l the rows reserved for it (pars.card(l) of them, laid
    % out consecutively in attribute order) are filled with
    % oneofc(Lin(l)+1, pars.card(l)) - presumably a one-of-C column vector;
    % confirm against oneofc. The result is single precision of size
    % sum(pars.card) x pars.batchsize.
    one_based = Lin + 1;
    last_row = cumsum(pars.card);
    L = zeros(sum(pars.card),pars.batchsize,'single');
    for l = 1:length(pars.card)
        first_row = last_row(l) - pars.card(l) + 1;
        code = oneofc(one_based(l), pars.card(l));
        L(first_row:last_row(l),:) = repmat(code, 1, pars.batchsize);
    end
end
