
%% Load model.
% Add the matcaffe bindings and the project helper folders to the search
% path. They are added one at a time, in this order, so the resulting path
% precedence is the same as three consecutive addpath calls.
search_dirs = {'../../matlab/caffe', '../util', '../model'};
for dir_idx = 1:numel(search_dirs),
    addpath(search_dirs{dir_idx});
end

% Hidden-unit count; not read anywhere else in this script.
%numhid = 256;
numhid = 512;
% Forwarded unchanged to matcaffe_init_sprites_train
% (presumably a GPU selector -- TODO confirm against that helper).
use_gpu = 2;

% Checkpoint to evaluate. Earlier candidates are kept for reference:
%init_from = 'results/sprites_manifold/analogy_sprite_large_dis_a10_m0_20150531T184216.mat';
%init_from = 'results/sprites_manifold2/analogy_sprite_manifold2_a10.0000_b0.0100_m0_e0_20150601T193938.mat';
%init_from = 'results/sprites_manifold2/analogy_sprite_manifold2_a10.0000_b0.0100_m0_e1_20150601T193836.mat';
init_from = 'results/sprites/analogy_sprites.mat';

% The checkpoint file carries the parameter struct used throughout the
% script (pars.vdim, pars.pose_idx, pars.id_idx).
model = load(init_from);
pars = model.pars;

% Initialise the encoder and decoder from their solver definitions and
% the same checkpoint file.
enc_solver = 'results/sprites/analogy_sprite_enc_solver.prototxt';
dec_solver = 'results/sprites/analogy_sprite_dec_solver.prototxt';
matcaffe_init_sprites_train(use_gpu, enc_solver, dec_solver, init_from);

%% Load data.
% Everything is read from the sprite data folder: the split indices, a
% listing of the .mat files and the label table.
datadir = '../data/sprites';
splits = load([datadir '/sprites_splits.mat']);
files  = dir([datadir '/*.mat']);
labels = load([datadir '/sprite_large_labels.txt']);

% Containers for the current batch; only batch_sprites is filled in below.
batch_sprites = struct();
batch_masks   = struct();
batch_labels  = struct();

% Enumerate every unordered pair of entries in splits.testidx and shuffle
% the rows with a fixed seed so the picks below are reproducible.
% NOTE: the row order produced by combnk determines which pairs are
% selected; substituting another enumeration would change the inputs.
testpairs = combnk(splits.testidx, 2);
rng('default');
rng(1);
R = randperm(size(testpairs,1));
testpairs = testpairs(R,:);

% One row per demo input. pair_rows selects the row of the shuffled
% testpairs that supplies [id1, id2]; anim_frames supplies
% [anim1, anim2, anim3, f1, f2, f3].
%   row 1: walk
%   row 2: thrust
%   row 3: shoot
% Disabled alternative for row 3 (rotate): pair row 17, [7 7 10 1 1 1].
pair_rows   = [ 1; 16; 12 ];
anim_frames = [ 11, 11, 10, 2, 3, 1; ...
                 6,  6,  8, 4, 5, 4; ...
                18, 18, 20, 4, 5, 2 ];

% Final layout per row: [id1, id2, anim1, anim2, anim3, f1, f2, f3].
inp = [ testpairs(pair_rows,:), anim_frames ];

%% Run the encoder and decoder on every selected input.
% For each row of inp this loads the two sprite files, pushes the
% reference and query animations through the encoder, decodes a predicted
% animation, and stores the reference frames (refs{a}) and the predicted
% frames (preds{a}) for visualisation.
batch_size = 25;              % number of frame slots in each padded input buffer
num_inputs = size(inp,1);
preds = cell(num_inputs,1);   % one cell per input row
refs = cell(num_inputs,1);
for a = 1:num_inputs,
    fprintf('input %d of %d\n', a, num_inputs);
    % Row layout: [id1, id2, anim1, anim2, anim3, f1, f2, f3].
    % Only the first four columns are consumed here.
    id1 = inp(a,1);
    id2 = inp(a,2);
    anim1 = inp(a,3);
    anim2 = inp(a,4);

    % Each file holds a cell array "sprites", indexed by animation; the
    % second dimension of an entry is its number of frames.
    data1 = load(sprintf('%s/sprites_%d.mat', datadir, id1));
    sprite1 = data1.sprites;
    data2 = load(sprintf('%s/sprites_%d.mat', datadir, id2));
    sprite2 = data2.sprites;
    nf1 = size(sprite1{anim1},2);
    nf2 = size(sprite2{anim2},2);

    % Indexed assignment would silently enlarge the padded buffers past
    % batch_size, so reject over-long animations explicitly.
    assert(nf1 <= batch_size && nf2 <= batch_size, ...
           'Animation has %d/%d frames but the input buffers hold %d.', ...
           nf1, nf2, batch_size);

    % X1: reference animation of sprite 1.
    % X2: the same animation index (anim1) taken from sprite 2.
    % The X2 reshape uses nf2, so sprite2{anim1} must have nf2 frames;
    % this holds whenever anim1 == anim2, as in every row of inp.
    batch_sprites.X1 = ...
        reshape(sprite1{anim1}, pars.vdim, pars.vdim, 3, nf1);
    batch_sprites.X2 = ...
        reshape(sprite2{anim1}, pars.vdim, pars.vdim, 3, nf2);

    % The first output frame is the real first frame of sprite 2; the
    % remaining frames are filled in from the decoder below.
    pred_anim = zeros(pars.vdim, pars.vdim, 3, nf2);
    pred_anim(:,:,:,1) = batch_sprites.X2(:,:,:,1);

    % Zero-pad both animations to batch_size frames in single precision.
    tmp1 = zeros(pars.vdim, pars.vdim, 3, batch_size, 'single');
    tmp2 = zeros(pars.vdim, pars.vdim, 3, batch_size, 'single');
    tmp1(:,:,:,1:nf1) = single(batch_sprites.X1(:,:,:,1:nf1));
    tmp2(:,:,:,1:nf2) = single(batch_sprites.X2(:,:,:,1:nf2));
    Xin = { tmp1; tmp2; tmp1; tmp2 };

    % Encoder (last argument 0; presumably the net index -- TODO confirm
    % against the matcaffe build in use).
    result = caffe('forward', Xin, 0);
    out = squeeze(result{1});
    ref = squeeze(result{3});

    trans = out - ref;
    trans = trans(pars.pose_idx,:);
    top_id = out(pars.id_idx,:);
    top_id = apply_softmax(top_id, pars);
    top_pose = ref(pars.pose_idx,:);
    sw = zeros(size(trans), 'single');

    % Decoder (last argument 1). The transformation input is zeroed, so
    % only its size is taken from trans.
    result = caffe('forward', ...
                   {0*trans, top_id, top_pose, sw, top_id }, 1);
    pred_mask = result{2};
    pred_sprite = result{3};

    % Apply the predicted mask to the predicted sprite.
    pred = bsxfun(@times, pred_sprite, pred_mask);
    pred_anim(:,:,:,2:nf2) = pred(:,:,:,2:nf2);

    refs{a} = tmp1(:,:,:,1:nf1);
    preds{a} = pred_anim;
end

%% Montage of reference frames (figure 1) and generated frames (figure 2).
% Each input occupies one row of tiles, one tile per frame, separated by
% "sep" pixels of white background.
sep = 2;
tile = 60;                    % tile height and width in pixels
num_rows = numel(preds);

% Size the canvas for the longest sequence in either set, so no row can
% grow the image past its white background.
max_frames = 0;
for row = 1:num_rows,
    max_frames = max([max_frames, size(refs{row},4), size(preds{row},4)]);
end
canvas_h = num_rows*tile + (num_rows-1)*sep;
canvas_w = max_frames*tile + (max_frames-1)*sep;

% Both canvases start out white (all ones).
vis_out = ones(canvas_h, canvas_w, 3);
vis_ref = vis_out;

rowstart = 1;
for row = 1:num_rows,
    rowend = rowstart + tile - 1;

    % Reference frames: iterate over their own frame count.
    colstart = 1;
    for col = 1:size(refs{row},4),
        colend = colstart + tile - 1;
        vis_ref(rowstart:rowend,colstart:colend,:) = refs{row}(:,:,:,col);
        colstart = colend + sep + 1;
    end

    % Generated frames.
    colstart = 1;
    for col = 1:size(preds{row},4),
        colend = colstart + tile - 1;
        vis_out(rowstart:rowend,colstart:colend,:) = preds{row}(:,:,:,col);
        colstart = colend + sep + 1;
    end

    rowstart = rowend + sep + 1;
end
figure(1);
imagesc(vis_ref); axis off;
figure(2);
imagesc(vis_out); axis off;

%% Generate video pairs.
% Writes one AVI per (input, kind) combination: the reference frames and
% the generated frames for each of the three inputs.
close all;

fr = 10;   % frame rate assigned to every VideoWriter

% Column 1: output file name. Column 2: frame stack; dimension 4 indexes
% the frames.
videos = { 'walking_ref.avi', refs{1};  ...
           'walking_gen.avi', preds{1}; ...
           'thrust_ref.avi',  refs{2};  ...
           'thrust_gen.avi',  preds{2}; ...
           'shoot_ref.avi',   refs{3};  ...
           'shoot_gen.avi',   preds{3} };

for v = 1:size(videos,1),
    clip = videos{v,2};
    writerObj = VideoWriter(videos{v,1});
    writerObj.FrameRate = fr;
    open(writerObj);
    try
        axis tight
        set(gca,'nextplot','replacechildren');
        set(gcf,'Renderer','zbuffer');
        % Each clip is bounded by its own frame count.
        for k = 1:size(clip,4),
            % Rows are reversed before display (presumably to compensate
            % for the axes' y-direction -- confirm visually).
            imagesc(clip(end:-1:1,:,:,k));
            frame = getframe;
            writeVideo(writerObj,frame);
        end
    catch err
        % Close the file before propagating the failure.
        close(writerObj);
        rethrow(err);
    end
    close(writerObj);
end

%%
%a=load('sprite_manifold_e0.mat');
%b=load('sprite_manifold_e1.mat');
%vis = a.vis;
%vis(end-60:end,:,:) = b.vis(end-60:end,:,:);
%imagesc(vis); axis off;
%print('-djpeg','sprite_manifold_clean');
