function test_sprites_error_add_cls()
% TEST_SPRITES_ERROR_ADD_CLS  Evaluate the add+cls analogy model on the
% sprites test split and print the mean reconstruction cost per animation.
%
% For each of pars.numanim animations and pars.numstep steps, an analogy
% batch (X1..X4) is built by sample_analogy from the pre-recorded file-index
% pairs in idx_hist.mat. The batch is run through the encoder net
% (caffe 'forward' with index 0), and the embedding for the fourth image is
% formed by vector addition (out - ref + query). That embedding is decoded
% (caffe 'forward' with index 1), and the reconstruction cost against X4 is
% taken from grad_recon. Costs are averaged over steps, reported in groups
% of four animations, and then averaged overall.
%
% Side effects: extends the MATLAB path, initialises Caffe through
% matcaffe_init_sprites_train, reseeds the global RNG, and prints a table.
masked = 0;
use_gpu = 2;
alpha = 1;

addpath('../../../matlab/caffe');
addpath('../util');
addpath('../');

subdir = 'sprites_add';
solver_enc = sprintf('results/%s/%s', subdir, 'analogy_sprite_enc_solver.prototxt');
solver_dec = sprintf('results/%s/%s', subdir, 'analogy_sprite_dec_solver.prototxt');

init_from = '/mnt/neocortex2/scratch/yeezhang/nips_final_models/add_cls.mat';
matcaffe_init_sprites_train(use_gpu, solver_enc, solver_dec, init_from);

%% Load data.
datadir = '../data/sprites';
splits = load('../data/sprites_splits.mat');

files = dir([datadir '/*.mat']);
testfiles = arrayfun(@(x) [ datadir '/' x.name ], files(splits.testidx), ...
                     'UniformOutput', false);

%% Init model.
pars = struct('numstep', 100, ...
              'numanim', 20, ...
              'alpha', alpha, ...
              'numhid', 1024, ...
              'vdim', 60, ...
              'masked', masked, ...
              'lambda_mask', 0, ...
              'lambda_rgb', 2, ...
              'numfactor', 8, ...
              'batchsize', 25);
% Number of states per categorical label entry. Entry 7 has 3 states
% because fix_wpn can rewrite label 7 to 0, 1 or 2.
pars.card = [ 2, 4, 3, 6, 2, 2, 3 ];
pars.num_categorical = sum(pars.card);
pars.id_idx = 1:pars.num_categorical;
pars.pose_idx = (pars.num_categorical+1):pars.numhid;

rng('default');
rng(1);

% NOTE(review): the train phase is selected even though only forward passes
% run here; confirm this is intended (it may affect phase-dependent layers).
caffe('set_phase_train');

% idx_hist(a,b,:) holds the file indices used for animation a, step b
% (sample_analogy reads idx(1) and idx(2) from it).
data_idx = load('idx_hist');
idx_hist = data_idx.idx_hist;

cost_hist = zeros(1, pars.numanim);
for a = 1:pars.numanim
    cost = 0.0;
    for b = 1:pars.numstep
        [ batch_data, batch_masks ] = ...
            sample_analogy(testfiles, pars, a, squeeze(idx_hist(a,b,:)));

        Xin = { single(batch_data.X1); single(batch_data.X2); ...
                single(batch_data.X3); single(batch_data.X4) };

        % Encoder.
        result = caffe('forward', Xin, 0);
        out = squeeze(result{1});
        query = squeeze(result{2});
        ref = squeeze(result{3});

        % Vector addition for analogy.
        top = out - ref + query;

        % Decoder.
        result = caffe('forward', { top }, 1);
        pred_mask = result{1};
        pred_sprite = result{2};

        % Reconstruction cost against the ground-truth fourth image; the
        % gradient outputs are requested (nargout = 3) but not needed here.
        [ cost_recon, ~, ~ ] = ...
            grad_recon(pred_sprite, pred_mask, batch_data.X4, batch_masks.X4, pars);

        cost = cost + cost_recon;
    end
    cost_hist(a) = cost / pars.numstep;
end

disp('Mean-squared pixel error of the add+cls model')
disp('---------------------------------------------')
% Animations are grouped in consecutive blocks of four.
anim_names = { 'Spell cast', 'Thrust', 'Walk', 'Slash', 'Shoot' };
for g = 1:numel(anim_names)
    fprintf('%s: %.2f\n', anim_names{g}, mean(cost_hist((4*g-3):(4*g))));
end
disp('---------------------------------------------')
fprintf('Average: %.2f\n', mean(cost_hist));
end

function [ batch_sprites, batch_masks, batch_labels] = ...
  sample_analogy(files, pars, idx_anim, idx)
% SAMPLE_ANALOGY  Assemble one analogy batch X1:X2 :: X3:X4 from two files.
%
%   files    - cell array of .mat file paths. Each loaded file must provide
%              'labels', 'sprites' and 'masks'; sprites and masks are cell
%              arrays indexed by animation.
%   pars     - parameter struct; pars.batchsize and pars.vdim are used.
%   idx_anim - animation index. It selects sprites{idx_anim} and
%              masks{idx_anim} and is passed to fix_wpn.
%   idx      - idx(1) and idx(2) select the two source files.
%
%   batch_sprites - struct with fields X1..X4, each [vdim, vdim, 3, batchsize].
%   batch_masks   - struct with fields X1..X4, each [vdim, vdim, 1, batchsize].
%   batch_labels  - struct with fields L1..L4. L1 and L2 are the first file's
%                   labels and L3 and L4 the second file's, each passed
%                   through fix_wpn.
%
%   Frame indices t1 are drawn with replacement from the columns of the
%   first mask sequence, and t2 is a random permutation of t1. X1 and X2 use
%   frames t1 and t2 of the first file; X3 and X4 use frames t1 and t2 of
%   the second file.

  first = idx(1);
  second = idx(2);

  src1 = load(files{first});
  src2 = load(files{second});

  batch_labels = struct;
  batch_labels.L1 = fix_wpn(src1.labels, idx_anim);
  batch_labels.L2 = fix_wpn(src1.labels, idx_anim);
  batch_labels.L3 = fix_wpn(src2.labels, idx_anim);
  batch_labels.L4 = fix_wpn(src2.labels, idx_anim);

  seq1 = src1.sprites{idx_anim};
  msk1 = src1.masks{idx_anim};
  seq2 = src2.sprites{idx_anim};
  msk2 = src2.masks{idx_anim};

  % randsample must be called before randperm so that the seeded RNG
  % stream is consumed in a fixed order.
  t1 = randsample(1:size(msk1, 2), pars.batchsize, 1);
  t2 = t1(randperm(numel(t1)));

  % Column k of these tables describes output field Xk.
  sprite_src = { seq1, seq1, seq2, seq2 };
  mask_src   = { msk1, msk1, msk2, msk2 };
  frames     = { t1, t2, t1, t2 };
  rgb_dims   = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_dims  = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites = struct;
  batch_masks = struct;
  for k = 1:4
    fld = sprintf('X%d', k);
    batch_sprites.(fld) = reshape(sprite_src{k}(:, frames{k}), rgb_dims);
    batch_masks.(fld) = reshape(mask_src{k}(:, frames{k}), mask_dims);
  end
end

function L = fix_wpn(L, idx_anim)
% FIX_WPN  Rewrite label entry 7 depending on the animation index.
%   When L(7) is 0 (commented as bow), it becomes 1 for idx_anim in 17..20
%   and 0 otherwise. When L(7) is 1 (commented as spear), it becomes 2 for
%   idx_anim in 5..12 and 0 otherwise. Any other value of L(7), and every
%   other entry of L, is returned unchanged.
  switch L(7)
    case 0  % bow
      L(7) = double((idx_anim >= 17) && (idx_anim <= 20));
    case 1  % spear
      L(7) = 2 * double((idx_anim >= 5) && (idx_anim <= 12));
  end
end

function L = flatten(Lin, pars)
% FLATTEN  Stack per-attribute encodings of a label vector into one matrix.
%   For each attribute k, oneofc(Lin(k) + 1, pars.card(k)) is written into
%   the next pars.card(k) rows and repeated across pars.batchsize columns.
%   The +1 shifts the labels to 1-based indices for oneofc.
%   Returns a single-precision matrix of size
%   [sum(pars.card), pars.batchsize].
    L = zeros(sum(pars.card), pars.batchsize, 'single');
    offset = 0;
    for k = 1:length(pars.card)
        width = pars.card(k);
        L(offset + (1:width), :) = ...
            repmat(oneofc(Lin(k) + 1, width), 1, pars.batchsize);
        offset = offset + width;
    end
end
