% TRAIN_SPRITES_FEWSHOT  Train a sprite analogy model in the few-shot setting.
%
%   train_sprites_fewshot(use_gpu, model, masked, alpha, holdout, numshot_idx, init_from)
%
%   All arguments are optional:
%     use_gpu     - set to 1 to train on GPU (default 0).
%     model       - 'dis', 'dis_cls' (default) or 'add'; selects the batch
%                   sampler and the per-step update rule.
%     masked      - copied into pars.masked (default 0).
%     alpha       - copied into pars.alpha (default 10.0).
%     holdout     - row of holdout_anims in sprites_fewshot_splits.mat that
%                   lists the held-out animations (default 1).
%     numshot_idx - which few-shot identity split to use; choice of
%                   [6, 12, 24, 48] (default 4).
%     init_from   - if specified, the model is initialized from the
%                   checkpoint saved in this file.
%
%   Every 11 steps a preview figure is printed to images/sprites_fewshot/,
%   and every 1000 steps the encoder/decoder weights plus pars are saved to
%   results/sprites_fewshot/<pars.fname>.mat.
function train_sprites_fewshot(use_gpu, model, masked, alpha, holdout, numshot_idx, init_from)

if ~exist('use_gpu', 'var')
    use_gpu = 0;
end

if ~exist('model', 'var')
    model = 'dis_cls';
end

if ~exist('masked', 'var')
    masked = 0;
end

if ~exist('alpha', 'var')
    alpha = 10.0;
end

if ~exist('holdout', 'var')
    holdout = 1;
end

if ~exist('numshot_idx', 'var')
    % choice of [6, 12, 24, 48].
    numshot_idx = 4;
end

% Resolve the sampler/update pair first so an unknown model name fails
% before Caffe is initialized and the data is loaded.
switch model
    case 'dis'
        sampler = @sample_dis_fewshot;
        update = @update_dis;
    case 'dis_cls'
        sampler = @sample_dis_fewshot;
        update = @update_dis_cls;
    case 'add'
        sampler = @sample_add_fewshot;
        update = @update_add;
    otherwise
        error('train_sprites_fewshot:unknownModel', ...
              'Unknown model ''%s''; expected ''dis'', ''dis_cls'' or ''add''.', model);
end

solver_enc = 'analogy_sprite_enc_solver.prototxt';
solver_dec = 'analogy_sprite_dec_solver.prototxt';

addpath('../util');
addpath('../../matlab/caffe');
addpath('../model');

subdir = 'sprites_fewshot';
solver_enc = sprintf('results/%s/%s', subdir, solver_enc);
solver_dec = sprintf('results/%s/%s', subdir, solver_dec);

if exist('init_from', 'var')
    matcaffe_init_sprites_train(use_gpu, solver_enc, solver_dec, init_from);
else
    matcaffe_init_sprites_train(use_gpu, solver_enc, solver_dec);
end

%% Load data.
datadir = '../data/sprites';
splits = load([datadir '/sprites_splits.mat']);
files = dir([datadir '/*.mat']);
tmp = files(splits.trainidx);
trainfiles = arrayfun(@(x) [ datadir '/' x.name ], tmp, 'UniformOutput', false);

hsplits = load([datadir '/sprites_fewshot_splits.mat']);
holdout_ids = hsplits.holdout_ids{numshot_idx};
holdout_anims = hsplits.holdout_anims(holdout,:);
numshot = numel(holdout_ids);

%% Init model.
pars = struct('numstep', 100000, ...
              'alpha', alpha, ...
              'numhid', 256, ...
              'vdim', 60, ...
              'masked', masked, ...
              'holdout', holdout, ...
              'numshot', numshot, ...
              'model', model, ...
              'lambda_mask', 1, ...
              'lambda_rgb', 10, ...
              'batchsize', 25);
pars.fname = sprintf('sprite_fewshot_%s_t%d_n%d_a%.4g_m%d_%s', ...
                     model, holdout, numshot, alpha, masked, datestr(now,30));
pars.fname_png = sprintf('images/%s/%s.png', subdir, pars.fname);
imgdir = sprintf('images/%s', subdir);
if ~exist(imgdir, 'dir')
    % mkdir warns when the folder already exists; only create it once.
    mkdir(imgdir);
end
pars.fname_mat = sprintf('results/%s/%s.mat', subdir, pars.fname);

% Per-attribute cardinalities. Their sum sets how many ID units receive the
% flattened labels in the 'dis_cls' model.
pars.card = [ 2, 4, 3, 6, 2, 2, 3 ];
pars.num_categorical = sum(pars.card);
if strcmp(model,'dis_cls')
    pars.id_idx = 1:pars.num_categorical;
    pars.pose_idx = (pars.num_categorical+1):pars.numhid;
else
    pars.id_idx = 1:128;
    pars.pose_idx = 129:256;
end
pars.holdout_ids = holdout_ids;
pars.holdout_anims = holdout_anims;

%% Train.
caffe('set_phase_train');
for b = 1:pars.numstep
    [ batch_data, batch_masks, batch_labels ] = sampler(trainfiles, pars, b);
    [ pred_sprite, cost, acc ] = update(batch_data, batch_masks, batch_labels, pars);

    % demo analogy.
    if mod(b,11)==0
        fprintf(1, 'step %d of %d, acc=%.4g, cost=%g\n', b, pars.numstep, acc, cost);
        subplot(1,4,1);
        imagesc(batch_data.X1(:,:,:,1));
        subplot(1,4,2);
        imagesc(batch_data.X2(:,:,:,1));
        subplot(1,4,3);
        imagesc(batch_data.X3(:,:,:,1));
        subplot(1,4,4);
        imagesc(mytrim(pred_sprite(:,:,:,1), 0, 1));
        drawnow;
        print('-dpng', pars.fname_png);
    end

    if mod(b,1000)==0
        % Separate name so the model-name string is not shadowed.
        checkpoint = struct('enc', caffe('get_weights', 0), ...
                            'dec', caffe('get_weights', 1), ...
                            'pars', pars);
        save(pars.fname_mat, '-struct', 'checkpoint');
    end
end

end

function [ batch_sprites, batch_masks, batch_labels ] = ...
    sample_add_fewshot(files, pars, batchidx)
% SAMPLE_ADD_FEWSHOT  Sample a batch of 4-tuples for the additive analogy model.
%
%   Two data files (id1, id2) and two animations (a1, a2) are chosen, and
%   frames are drawn so that
%     X1 = (id1, a1, frames t1),  X2 = (id1, a2, frames t2),
%     X3 = (id2, a1, frames t1),  X4 = (id2, a2, frames t2).
%
%   Even batchidx: two distinct entries of pars.holdout_ids are used, and a2
%   is one of pars.holdout_anims (few-shot case).
%   Odd batchidx: two distinct entries of files are used, and a2 is not a
%   held-out animation.
%   a1 is never a held-out animation.
%
%   files - cell array of .mat paths. Each file is loaded for its sprites
%           and masks (cell arrays indexed by animation, one flattened frame
%           per column) and its labels.
%   Returns structs with fields X1..X4 (sprites reshaped to
%   vdim x vdim x 3 x batchsize, masks to vdim x vdim x 1 x batchsize) and
%   labels L1..L4, each passed through fix_wpn with the animation index of
%   the matching sprite.
  batch_sprites = struct;
  batch_masks = struct;
  batch_labels = struct;

  if mod(batchidx,2)==0
    % X2/X4 use a held-out animation (few shot).
    holdout_pose = true;
    ids = pars.holdout_ids(randperm(numel(pars.holdout_ids), 2));
  else
    % No held-out animations involved.
    holdout_pose = false;
    ids = randperm(numel(files), 2);
  end
  idx1 = ids(1);
  idx2 = ids(2);

  data1 = load(files{idx1});
  data2 = load(files{idx2});

  % Candidate animations: every index except the last, minus the held-out ones.
  non_holdout_pose = setdiff(1:length(data1.sprites)-1, pars.holdout_anims);
  % Pick by indexing rather than randsample(v,1). randsample treats a
  % scalar v as the range 1:v, which would sample the wrong animation when
  % only one candidate is left.
  idx_anim1 = non_holdout_pose(randi(numel(non_holdout_pose)));
  if holdout_pose
    idx_anim2 = pars.holdout_anims(randi(numel(pars.holdout_anims)));
  else
    idx_anim2 = non_holdout_pose(randi(numel(non_holdout_pose)));
  end

  sprite1 = data1.sprites{idx_anim1};
  mask1 = data1.masks{idx_anim1};
  sprite2 = data1.sprites{idx_anim2};
  mask2 = data1.masks{idx_anim2};
  sprite3 = data2.sprites{idx_anim1};
  mask3 = data2.masks{idx_anim1};
  sprite4 = data2.sprites{idx_anim2};
  mask4 = data2.masks{idx_anim2};

  % Each label set is adjusted for the animation its sprite comes from.
  batch_labels.L1 = fix_wpn(data1.labels, idx_anim1);
  batch_labels.L2 = fix_wpn(data1.labels, idx_anim2);
  batch_labels.L3 = fix_wpn(data2.labels, idx_anim1);
  batch_labels.L4 = fix_wpn(data2.labels, idx_anim2);

  % Frame indices are sampled with replacement. X1/X3 share t1 and X2/X4
  % share t2, so paired sprites show the same frames.
  t1_idx = randi(size(mask1,2), 1, pars.batchsize);
  t2_idx = randi(size(mask2,2), 1, pars.batchsize);

  sprite_dims = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_dims = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites.X1 = reshape(sprite1(:,t1_idx), sprite_dims);
  batch_sprites.X2 = reshape(sprite2(:,t2_idx), sprite_dims);
  batch_sprites.X3 = reshape(sprite3(:,t1_idx), sprite_dims);
  batch_sprites.X4 = reshape(sprite4(:,t2_idx), sprite_dims);

  batch_masks.X1 = reshape(mask1(:,t1_idx), mask_dims);
  batch_masks.X2 = reshape(mask2(:,t2_idx), mask_dims);
  batch_masks.X3 = reshape(mask3(:,t1_idx), mask_dims);
  batch_masks.X4 = reshape(mask4(:,t2_idx), mask_dims);
end

function [ batch_sprites, batch_masks, batch_labels ] = ...
    sample_dis_fewshot(files, pars, batchidx)
% SAMPLE_DIS_FEWSHOT  Sample a batch of triplets for the disentangling models.
%
%   Two data files (id1, id2) and two animations (a1, a2) are chosen, and
%   frames are drawn so that
%     X1 = (id1, a1, frames t1),  X2 = (id2, a2, frames t2),
%     X3 = (id1, a2, frames t2),
%   i.e. X3 combines the identity of X1 with the animation and frames of X2.
%
%   Even batchidx: both identities are drawn with replacement from
%   pars.holdout_ids, and a2 is one of pars.holdout_anims (few-shot case).
%   Odd batchidx: both identities are drawn from all of files, and a2 is not
%   a held-out animation.
%   a1 is never a held-out animation.
%
%   files - cell array of .mat paths. Each file is loaded for its sprites
%           and masks (cell arrays indexed by animation, one flattened frame
%           per column) and its labels.
%   Returns structs with fields X1..X3 (sprites reshaped to
%   vdim x vdim x 3 x batchsize, masks to vdim x vdim x 1 x batchsize) and
%   labels L1..L3, each passed through fix_wpn with the animation index of
%   the matching sprite.
  batch_sprites = struct;
  batch_masks = struct;
  batch_labels = struct;

  if mod(batchidx,2)==0
    % Few-shot case: identities from the held-out set, X2/X3 use a
    % held-out animation.
    holdout_pose = true;
    ids = pars.holdout_ids(randi(numel(pars.holdout_ids), 1, 2));
    idx1 = ids(1);
    idx2 = ids(2);
  else
    % Regular case: any identities, no held-out animations.
    holdout_pose = false;
    idx1 = randi(numel(files));
    idx2 = randi(numel(files));
  end

  data1 = load(files{idx1});
  data2 = load(files{idx2});

  % Candidate animations: every index except the last, minus the held-out ones.
  non_holdout_pose = setdiff(1:length(data1.sprites)-1, pars.holdout_anims);
  % Pick by indexing rather than randsample(v,1). randsample treats a
  % scalar v as the range 1:v, which would sample the wrong animation when
  % only one candidate is left.
  idx_anim1 = non_holdout_pose(randi(numel(non_holdout_pose)));
  if holdout_pose
    idx_anim2 = pars.holdout_anims(randi(numel(pars.holdout_anims)));
  else
    idx_anim2 = non_holdout_pose(randi(numel(non_holdout_pose)));
  end

  sprite1 = data1.sprites{idx_anim1};
  mask1 = data1.masks{idx_anim1};
  sprite2 = data2.sprites{idx_anim2};
  mask2 = data2.masks{idx_anim2};
  sprite3 = data1.sprites{idx_anim2};
  mask3 = data1.masks{idx_anim2};

  % Each label set is adjusted for the animation its sprite comes from.
  % X3 is drawn from animation idx_anim2, so L3 must use idx_anim2 too.
  batch_labels.L1 = fix_wpn(data1.labels, idx_anim1);
  batch_labels.L2 = fix_wpn(data2.labels, idx_anim2);
  batch_labels.L3 = fix_wpn(data1.labels, idx_anim2);

  % Frame indices are sampled with replacement. X3 reuses t2 so it shows the
  % same frames as X2.
  t1_idx = randi(size(mask1,2), 1, pars.batchsize);
  t2_idx = randi(size(mask2,2), 1, pars.batchsize);

  sprite_dims = [ pars.vdim, pars.vdim, 3, pars.batchsize ];
  mask_dims = [ pars.vdim, pars.vdim, 1, pars.batchsize ];

  batch_sprites.X1 = reshape(sprite1(:,t1_idx), sprite_dims);
  batch_sprites.X2 = reshape(sprite2(:,t2_idx), sprite_dims);
  batch_sprites.X3 = reshape(sprite3(:,t2_idx), sprite_dims);

  batch_masks.X1 = reshape(mask1(:,t1_idx), mask_dims);
  batch_masks.X2 = reshape(mask2(:,t2_idx), mask_dims);
  batch_masks.X3 = reshape(mask3(:,t2_idx), mask_dims);
end

function [ pred_sprite, cost, acc ] = ...
        update_dis_cls(batch_data, batch_masks, batch_labels, pars)
    % UPDATE_DIS_CLS  One training step of the disentangling model whose ID
    % units are supervised with attribute labels.
    %
    %   Net 0 (encoder) is run on {X1; X2; X3}. Its outputs 1..3 are used as
    %   the codes of X2, X3 and X1 respectively. grad_hid_disc2 returns a
    %   cost, an accuracy and gradients for the X1/X2 codes given labels
    %   L1/L2. Net 1 (decoder) receives the X2 code with its pars.id_idx rows
    %   overwritten by myflatten(L3), and is trained to reconstruct X3's
    %   sprite and mask. The reconstruction gradient replaces the
    %   pars.pose_idx rows of the X2 gradient. The X3 code gets a zero
    %   gradient.
    %
    %   Returns the predicted sprite, cost = regularizer + reconstruction,
    %   and the accuracy from grad_hid_disc2.
    enc_in = cell(3, 1);
    enc_in{1} = single(batch_data.X1);
    enc_in{2} = single(batch_data.X2);
    enc_in{3} = single(batch_data.X3);

    % Encoder forward pass.
    enc_out = caffe('forward', enc_in, 0);
    code_x2 = squeeze(enc_out{1});
    code_x3 = squeeze(enc_out{2});
    code_x1 = squeeze(enc_out{3});

    % Label regularizer on the X1/X2 codes.
    [ reg_cost, reg_grads, acc ] = grad_hid_disc2(code_x1, code_x2, ...
        batch_labels.L1, batch_labels.L2, pars);

    % Decoder input: X2 code with its ID rows set from the X3 labels.
    dec_in = code_x2;
    dec_in(pars.id_idx,:) = myflatten(batch_labels.L3, pars);

    dec_out = caffe('forward', { dec_in }, 1);
    pred_mask = dec_out{1};
    pred_sprite = dec_out{2};

    [ recon_cost, grad_sprite, grad_mask ] = grad_recon(pred_sprite, pred_mask, ...
        batch_data.X3, batch_masks.X3, pars);

    % Decoder backward + step; keep the gradient w.r.t. its input.
    dec_grads = caffe('backward', { grad_mask; grad_sprite }, 1);
    grad_top = squeeze(dec_grads{1});
    caffe('update', 1);

    % Encoder backward + step: pose rows of the X2 gradient come from the
    % reconstruction; the X3 code is unused and gets zero.
    grad_x2 = reg_grads{2};
    grad_x2(pars.pose_idx,:) = grad_top(pars.pose_idx,:);
    caffe('backward', { grad_x2; 0*code_x3; reg_grads{1} }, 0);
    caffe('update', 0);

    cost = reg_cost + recon_cost;
end

function [ pred_sprite, cost, acc ] = ...
        update_dis(batch_data, batch_masks, batch_labels, pars)
    % UPDATE_DIS  One training step of the unsupervised disentangling model.
    %
    %   Net 0 (encoder) is run on {X1; X2; X3}. Its outputs 1..3 are used as
    %   the codes of X2, X3 and X1 respectively. Net 1 (decoder) receives a
    %   code that takes its pars.id_idx rows from X1 and its pars.pose_idx
    %   rows from X2 (all other rows zero), and is trained to reconstruct
    %   X3's sprite and mask. The reconstruction gradient is routed back to
    %   the same rows of the X1/X2 codes. The X3 code gets a zero gradient.
    %
    %   Returns the predicted sprite, the reconstruction cost, and acc = 0
    %   (no classifier in this model).
    acc = 0;

    % Encoder forward pass.
    enc_out = caffe('forward', { single(batch_data.X1); single(batch_data.X2); ...
                                 single(batch_data.X3) }, 0);
    code_x2 = squeeze(enc_out{1});
    code_x3 = squeeze(enc_out{2});
    code_x1 = squeeze(enc_out{3});

    % Mix ID rows from X1 with pose rows from X2.
    dec_in = 0*code_x1;
    dec_in(pars.id_idx,:) = code_x1(pars.id_idx,:);
    dec_in(pars.pose_idx,:) = code_x2(pars.pose_idx,:);

    dec_out = caffe('forward', { dec_in }, 1);
    pred_mask = dec_out{1};
    pred_sprite = dec_out{2};

    [ cost, grad_sprite, grad_mask ] = grad_recon(pred_sprite, pred_mask, ...
        batch_data.X3, batch_masks.X3, pars);

    % Decoder backward + step; keep the gradient w.r.t. its input.
    dec_grads = caffe('backward', { grad_mask; grad_sprite }, 1);
    grad_top = squeeze(dec_grads{1});
    caffe('update', 1);

    % Send each slice of the input gradient back to the code it came from.
    grad_x1 = 0*grad_top;
    grad_x1(pars.id_idx,:) = grad_top(pars.id_idx,:);
    grad_x2 = 0*grad_top;
    grad_x2(pars.pose_idx,:) = grad_top(pars.pose_idx,:);

    caffe('backward', { grad_x2; 0*code_x3; grad_x1 }, 0);
    caffe('update', 0);
end

function [ pred_sprite, cost, acc ] = ...
        update_add(batch_data, batch_masks, batch_labels, pars)
    % UPDATE_ADD  One training step of the additive analogy model.
    %
    %   Net 0 (encoder) is run on {X1; X2; X3}, and its three outputs are
    %   bound as out = result{1}, query = result{2}, ref = result{3}.
    %   Net 1 (decoder) receives query + out - ref and is trained to
    %   reconstruct X4's sprite and mask.
    %
    %   Why X4: sample_add_fewshot builds X1=(id1,a1), X2=(id1,a2),
    %   X3=(id2,a1), X4=(id2,a2). The analogy X3 + (X2 - X1) therefore
    %   corresponds to X4. X3 is already an encoder input, so it is not a
    %   meaningful target.
    %   NOTE(review): this assumes ref is the X1 code, matching the output
    %   order update_dis/update_dis_cls rely on (result{3} <-> X1). Confirm
    %   against the encoder prototxt.
    %
    %   The encoder backward receives exactly one gradient per output (three),
    %   as the other update rules that share this solver do.
    %
    %   Returns the predicted sprite, the reconstruction cost, and acc = 0
    %   (no classifier in this model). batch_labels is unused.
    acc = 0;
    Xin = { single(batch_data.X1); single(batch_data.X2); ...
            single(batch_data.X3) };

    % Encoder.
    result = caffe('forward', Xin, 0);
    out = squeeze(result{1});
    query = squeeze(result{2});
    ref = squeeze(result{3});

    % Analogy by vector arithmetic in code space.
    top = query + out - ref;

    % Decoder.
    result = caffe('forward', { top }, 1);
    pred_mask = result{1};
    pred_sprite = result{2};

    % Reconstruction cost against the fourth element of the analogy.
    [ cost_recon, delta_sprite, delta_mask ] = ...
        grad_recon(pred_sprite, pred_mask, batch_data.X4, batch_masks.X4, pars);

    % Update decoder.
    delta_recon = caffe('backward', { delta_mask; delta_sprite }, 1);
    delta_recon = squeeze(delta_recon{1});
    caffe('update', 1);

    % d(top)/d(out) = d(top)/d(query) = +1, d(top)/d(ref) = -1.
    delta_out = delta_recon;
    delta_query = delta_recon;
    delta_ref = -delta_recon;

    % Update encoder: one gradient per encoder output.
    caffe('backward', { delta_out; delta_query; delta_ref }, 0);
    caffe('update', 0);

    cost = cost_recon;
end
