
%% Load model.
% Evaluation settings. imgsize and numhid are not read again in the visible
% part of this script (later sections use pars and literal sizes).
use_gpu = 2;
imgsize = [ 60 60 3 ];
numhid = 256;

% Make the utility functions and the MATLAB caffe wrapper reachable.
% Both paths are relative to the current working directory.
addpath('../util');
addpath('../../matlab/caffe');

% Snapshot selection. The two comment blocks below list alternative
% snapshots; copy one of the assignments below the blocks to evaluate it.
%{
% add t1 n6
init_from = 'results/sprites_fewshot/sprite_fewshot_add_t1_n5_a0_m0_20150518T180657.mat';
% add t1 n12
init_from = 'results/sprites_fewshot/sprite_fewshot_add_t1_n5_a0_m0_20150518T180723.mat';
% add t1 n24
init_from = 'results/sprites_fewshot/sprite_fewshot_add_t1_n5_a0_m0_20150518T180733.mat';
% add t1 n48
init_from = 'results/sprites_fewshot/sprite_fewshot_add_t1_n5_a0_m0_20150518T180742.mat';
%}
%{
% dis t1 n6
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_t1_n6_a0_m0_20150519T124721.mat';
% dis t1 n12
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_t1_n5_a0_m0_20150518T181008.mat';
% dis t1 n24
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_t1_n5_a0_m0_20150518T181017.mat';
% dis t1 n48
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_t1_n48_a0_m0_20150519T124836.mat';
% dis_cls t1 n6
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_cls_t1_n5_a10_m0_20150518T181034.mat';
% dis_cls t1 n12
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_cls_t1_n5_a10_m0_20150518T181029.mat';
% dis_cls t1 n24
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_cls_t1_n24_a10_m0_20150519T180359.mat';
%}

% Currently active snapshot: dis_cls t1 n48
init_from = 'results/sprites_fewshot/sprite_fewshot_dis_cls_t1_n48_a10_m0_20151024T213125.mat';

% Read the snapshot and keep the pars struct stored with it; pars drives
% every later section (model type, batch size, holdout lists, index sets).
model = load(init_from);
pars = model.pars;

% Set up the networks from the two solver definitions and the snapshot.
% NOTE(review): matcaffe_init_sprites_train is defined outside this file;
% presumably it builds the encoder/decoder nets -- confirm.
matcaffe_init_sprites_train( ...
    use_gpu, ...
    'results/sprites_fewshot/analogy_sprite_enc_solver.prototxt', ...
    'results/sprites_fewshot/analogy_sprite_dec_solver.prototxt', ...
    init_from);

%% Load data.
% Everything lives under datadir: the split definition, one .mat file per
% entry of the directory listing, and a text file of labels.
datadir = '../data/sprites';
splits = load([datadir '/sprites_splits.mat']);
files = dir([datadir '/*.mat']);
labels = load([datadir '/sprite_large_labels.txt']);

% Pick the training entries of the listing and turn each name into a path
% relative to the working directory. The result keeps the shape of tmp.
tmp = files(splits.trainidx);
trainfiles = reshape( ...
    cellfun(@(nm) [ datadir '/' nm ], { tmp.name }, 'UniformOutput', false), ...
    size(tmp));
trainlabels = labels(splits.trainidx,:);

% Positions (into trainfiles) of characters that are not held out, and the
% animation indices (out of 1..20) that are not held out.
train_ids = setdiff(1:numel(trainfiles), pars.holdout_ids);
non_holdout_anim = setdiff(1:20, pars.holdout_anims);

%% Evaluate few shot performance on train set holdout animations.

% Extract pose features for few-shot animations.
% After this section pose_fea{i} holds one column per frame of held-out
% animation pars.holdout_anims(i), averaged over the characters listed in
% pars.holdout_ids.
pose_fea = cell(numel(pars.holdout_anims),1);
num_holdout_ids = numel(pars.holdout_ids);

% If 'dis', we just directly extract the pose features.
% Otherwise, we need to extract the *transformation* from a fixed starting
% pose to each held-out pose.

if ~strcmp(pars.model,'add'),
    for i = 1:numel(pars.holdout_anims),
        for j = 1:num_holdout_ids,
            fprintf(1,'Extracting pose features for anim %d from char %d\n', ...
                    pars.holdout_anims(i), pars.holdout_ids(j));
            data = load(trainfiles{pars.holdout_ids(j)});

            % One column per frame -> 60x60x3xframes, zero-padded up to the
            % fixed batch size the net is fed with.
            inp = data.sprites{pars.holdout_anims(i)};
            inp = reshape(single(inp), [60 60 3 size(inp,2) ]);
            tmp = zeros([60 60 3 pars.batchsize],'single');
            tmp(:,:,:,1:size(inp,4)) = inp;

            % NOTE(review): the trailing 0 presumably selects the encoder
            % net in this custom matcaffe build -- confirm.
            res = caffe('forward', { tmp; tmp; tmp }, 0);
            hid_j = squeeze(res{1});

            % Running sum over characters; padding columns are dropped.
            if isempty(pose_fea{i}),
                pose_fea{i} = hid_j(:,1:size(inp,4));
            else
                pose_fea{i} = pose_fea{i} + hid_j(:,1:size(inp,4));
            end
        end
        pose_fea{i} = pose_fea{i} / num_holdout_ids;
    end
else
    anchor_anim = 7; % animation whose first frame is the fixed starting pose
    for i = 1:numel(pars.holdout_anims),
        for j = 1:num_holdout_ids,
            fprintf(1,'Extracting pose features for anim %d from char %d\n', ...
                    pars.holdout_anims(i), pars.holdout_ids(j));

            data = load(trainfiles{pars.holdout_ids(j)});

            % Reference input: first frame of the anchor animation,
            % replicated across the whole batch.
            inp_ref = data.sprites{anchor_anim};
            inp_ref = repmat(inp_ref(:,1),1,pars.batchsize);
            inp_ref = reshape(single(inp_ref), [ 60 60 3 pars.batchsize ]);

            % Target input: the nf frames of the held-out animation,
            % zero-padded up to the batch size.
            inp_out = data.sprites{pars.holdout_anims(i)};
            nf = size(inp_out,2);
            inp_out = reshape(single(inp_out), [60 60 3 nf ]);
            tmp = zeros([60 60 3 pars.batchsize],'single');
            tmp(:,:,:,1:nf) = inp_out;
            inp_out = tmp;

            res = caffe('forward', { inp_ref; inp_out; inp_out }, 0);
            hid_out = squeeze(res{1});
            hid_ref = squeeze(res{3});

            % Transformation from the starting pose to every target frame,
            % accumulated over characters.
            trans = hid_out - hid_ref;
            if isempty(pose_fea{i}),
                pose_fea{i} = trans(:,1:nf);
            else
                pose_fea{i} = pose_fea{i} + trans(:,1:nf);
            end
        end
        % Average over the held-out characters, mirroring the non-'add'
        % branch. (Without this the sum was discarded and only the last
        % character's transformation survived.)
        pose_fea{i} = pose_fea{i} / num_holdout_ids;
    end
end
%% Transfer pose features to all training set images.
% For every non-held-out character, decode each held-out animation from the
% pose features extracted above and accumulate the reconstruction error
% against that character's ground-truth frames.
% NOTE(review): caffe('forward', ..., 1) presumably runs the decoder net and
% res{2} is taken as its RGB output -- confirm against the net definition.
cost = 0;
if strcmp(pars.model,'add'),
    for i = 1:numel(train_ids),
        fprintf(1,'Computing cost for character %d of %d\n', ...
                i, numel(train_ids));
        data = load(trainfiles{train_ids(i)});

        % Compute query feature: embedding of the first frame of the anchor
        % animation (anchor_anim is set by the 'add' branch of the pose
        % extraction section).
        inp_query = data.sprites{anchor_anim};
        inp_query = repmat(inp_query(:,1),1,pars.batchsize);
        inp_query = reshape(single(inp_query), [ 60 60 3 pars.batchsize ]);
        res = caffe('forward', { inp_query; inp_query; inp_query }, 0);
        hid_query = squeeze(res{1});
        hid_query = hid_query(:,1);

        % Now transfer the pose.
        for j = 1:numel(pars.holdout_anims),
            trans = pose_fea{j};
            nf = size(trans,2);

            % Zero-pad to the fixed batch size, exactly as the non-'add'
            % branch does, so the decoder always sees a full batch; only
            % the first nf outputs are used.
            inp = zeros(size(hid_query,1), pars.batchsize, 'single');
            inp(:,1:nf) = bsxfun(@plus, hid_query, trans);
            res = caffe('forward', { inp }, 1);
            rgb = res{2}(:,:,:,1:nf);

            % Cast like every other sprite read so the subtraction is
            % well-defined whatever numeric class the file stores.
            gt = single(data.sprites{pars.holdout_anims(j)});

            % NOTE(review): err is a single column, so sum(...,1) collapses
            % it to a scalar and this is the L2 norm over the whole
            % animation, not a per-frame mean. Kept as is so reported
            % numbers stay comparable.
            err = rgb(:) - gt(:);
            cost = cost + mean(sqrt(sum(err.^2,1))) / numel(pars.holdout_anims);
        end
    end
else
    for i = 1:numel(train_ids),
        fprintf(1,'Computing cost for character %d of %d\n', ...
                i, numel(train_ids));
        data = load(trainfiles{train_ids(i)});
        id_fea = zeros(pars.numhid,1,'single');

        % Compute ID fea by averaging ID prediction across all animations.
        for j = 1:numel(non_holdout_anim),
            inp = data.sprites{non_holdout_anim(j)};
            inp = reshape(single(inp), [60 60 3 size(inp,2) ]);
            tmp = zeros([60 60 3 pars.batchsize],'single');
            tmp(:,:,:,1:size(inp,4)) = inp;
            res = caffe('forward', { tmp; tmp; tmp }, 0);
            hid_j = squeeze(res{1});
            % Mean over the real frames only; padding columns are ignored.
            id_fea = id_fea + mean(hid_j(:,1:size(inp,4)),2);
        end
        id_fea = id_fea / numel(non_holdout_anim);

        % Apply softmax to id_fea (only for the 'dis_cls' model).
        if strcmp(pars.model, 'dis_cls'),
            id_fea = apply_softmax(id_fea, pars);
        end

        % Now transfer the pose.
        for j = 1:numel(pars.holdout_anims),
            % Start from the averaged held-out features and overwrite the
            % identity units with this character's identity feature.
            tmp = pose_fea{j};
            tmp(pars.id_idx,:) = repmat(id_fea(pars.id_idx), 1, size(tmp,2));
            inp = zeros(pars.numhid, pars.batchsize, 'single');
            inp(:,1:size(tmp,2)) = tmp;
            res = caffe('forward', { inp }, 1);
            rgb = res{2}(:,:,:,1:size(tmp,2));

            % Cast like every other sprite read so the subtraction is
            % well-defined whatever numeric class the file stores.
            gt = single(data.sprites{pars.holdout_anims(j)});

            % NOTE(review): whole-animation L2 norm, see the 'add' branch.
            err = rgb(:) - gt(:);
            cost = cost + mean(sqrt(sum(err.^2,1))) / numel(pars.holdout_anims);
        end
    end
end
% Average over characters (each character already averaged over the
% held-out animations inside the loops).
cost = cost / numel(train_ids);
disp('held out animations: ');
disp(pars.holdout_anims);
fprintf(1,'average cost = %.4g\n', cost);

%% Sanity check.
% Single visual analogy: img1 -> img2 is a pose change on a held-out
% character (animation 7 to animation 2, frame 6); the same change is
% applied to img3 (animation 7, frame 6 of a training character). img4 is
% that training character's animation 2, frame 6, saved for comparison.
data1 = load(trainfiles{pars.holdout_ids(1)});
data2 = load(trainfiles{train_ids(4)});
anim1 = data1.sprites{7};
anim2 = data1.sprites{2};
anim3 = data2.sprites{7};
anim4 = data2.sprites{2};
img1 = single(reshape(anim1(:,6),[60 60 3 1]));
img2 = single(reshape(anim2(:,6),[60 60 3 1]));
img3 = single(reshape(anim3(:,6),[60 60 3 1]));
img4 = single(reshape(anim4(:,6),[60 60 3 1]));

% The nets are fed fixed-size batches, so replicate each image across the
% batch; only the first element is displayed.
img1 = repmat(img1,[1,1,1,pars.batchsize]);
img2 = repmat(img2,[1,1,1,pars.batchsize]);
img3 = repmat(img3,[1,1,1,pars.batchsize]);
img4 = repmat(img4,[1,1,1,pars.batchsize]);

inp = { img1; img2; img3 };
res = caffe('forward', inp, 0);
hid_out = squeeze(res{1});
hid_query = squeeze(res{2});
hid_ref = squeeze(res{3});
if strcmp(pars.model,'add'),
    % Additive analogy: query embedding plus the (out - ref) offset.
    top = hid_query + hid_out - hid_ref;
else
    % Disentangled models: identity units stay those of the query, pose
    % units are taken from the target.
    top = hid_query;
    top(pars.pose_idx,:) = hid_out(pars.pose_idx,:);
    % Softmax only for 'dis_cls', matching the evaluation section, which
    % applies apply_softmax to that model alone.
    if strcmp(pars.model, 'dis_cls'),
        top = apply_softmax(top, pars);
    end
end
res = caffe('forward', { top }, 1);
pred = res{2};

% Show the three inputs and the prediction side by side.
figure(1);
subplot(1,4,1);
imagesc(img1(:,:,:,1));
subplot(1,4,2);
imagesc(img2(:,:,:,1));
subplot(1,4,3);
imagesc(img3(:,:,:,1));
subplot(1,4,4);
%imagesc(pred(:,:,:,1));
% mytrim(...,0,1) presumably clamps the prediction to [0,1] for display.
imagesc(mytrim(pred(:,:,:,1),0,1));

% Saved per model type; the paper-figure section reads these files back.
save(sprintf('example_fewshot_%s.mat', pars.model), ...
     'img1', 'img2', 'img3', 'img4', 'pred');

%% Generate paper figure.
% Builds a 3x4 image grid from the example files written by the sanity
% check: one row per model (add, dis, dis_cls), columns img1, img2, img3
% and the prediction. All three example files must already exist.
fig_add = load('example_fewshot_add.mat');
fig_dis = load('example_fewshot_dis.mat');
fig_dis_cls = load('example_fewshot_dis_cls.mat');
img4 = fig_dis_cls.img4; % kept in the workspace; not drawn in the grid
vis = zeros(60,60,3,4,3);
figs = { fig_add, fig_dis, fig_dis_cls };
for f = 1:numel(figs),
    vis(:,:,:,1,f) = figs{f}.img1(:,:,:,1);
    vis(:,:,:,2,f) = figs{f}.img2(:,:,:,1);
    vis(:,:,:,3,f) = figs{f}.img3(:,:,:,1);
    % Trim the prediction the same way the sanity check does before
    % display, so the truecolor data stays within the displayable range.
    vis(:,:,:,4,f) = mytrim(figs{f}.pred(:,:,:,1),0,1);
end
% [row, col, channel, column-slot, model] -> [row, model, col, slot, channel],
% then merge (row, model) and (col, slot) into one 180x240x3 RGB image.
vis = permute(vis,[1 5 2 4 3]);
vis = reshape(vis,[60*3, 60*4, 3]);

% Use a dedicated, cleared figure: without this, running the whole script
% draws the grid into the last subplot left over from the sanity check.
figure(2); clf;
imagesc(vis); axis off;
% NOTE(review): the (name, format) call form suggests a savefig utility
% from the added paths rather than MATLAB's built-in savefig -- confirm.
savefig('sprites_fewshot_example','jpeg');
