
%% Load model.
% Put the utility, model and MatCaffe folders on the search path. The
% folders are added one at a time, in this order.
search_dirs = {'../util', '../model', '../../matlab/caffe'};
for d = 1:numel(search_dirs),
    addpath(search_dirs{d});
end

numhid = 512;   % not referenced again in the visible part of this script
use_gpu = -1;   % forwarded unchanged to matcaffe_init_sprites_train below

% Checkpoint to restore. Alternative checkpoint kept for reference:
%init_from = 'results/sprites_transfer/analogy_sprite_transfer.mat';
% The assignment below has no trailing semicolon, so MATLAB echoes the path.
init_from = '/home/reedscot/tmp/analogy_sprite_manifold_a10.0000_b20.0000_m0_e0_20151208T231549.mat'

% Solver definitions handed to the initialiser together with the checkpoint.
enc_solver = 'results/sprites_transfer/analogy_sprite_enc_solver.prototxt';
dec_solver = 'results/sprites_transfer/analogy_sprite_dec_solver.prototxt';

% The script itself only reads the 'pars' field of the loaded checkpoint.
model = load(init_from);
pars = model.pars;
matcaffe_init_sprites_train(use_gpu, enc_solver, dec_solver, init_from);

%% Load data.
datadir = '../data/sprites';

% splits.testidx is read further down to enumerate test pairs. 'files' and
% 'labels' are loaded here but not read again in the visible script.
splits = load(sprintf('%s/sprites_splits.mat', datadir));
files  = dir(sprintf('%s/*.mat', datadir));
labels = load(sprintf('%s/sprite_large_labels.txt', datadir));

% Field-less structs; only batch_sprites receives fields in this section.
batch_sprites = struct();
batch_masks   = struct();
batch_labels  = struct();

% Four zero-filled buffers of size vdim x vdim x 3 x batchsize (X1..X4).
for slot = {'X1', 'X2', 'X3', 'X4'},
    batch_sprites.(slot{1}) = zeros(pars.vdim, pars.vdim, 3, pars.batchsize);
end

% Enumerate every 2-element combination of the test indices, then shuffle
% the rows with a fixed seed so the rows picked below are reproducible.
testpairs = combnk(splits.testidx,2);
rng('default');
rng(1);
testpairs = testpairs(randperm(size(testpairs,1)),:);

% One row per demo clip:
%   [row of testpairs, animation index for id1, animation index for id2]
clip_spec = [  1, 12, 10; ...   % walk
              16,  7,  8; ...   % thrust
              12,  4,  3 ];     % shoot

% inp rows are [id1, id2, anim1, anim2], in the order walk, thrust, shoot.
inp = [testpairs(clip_spec(:,1),:), clip_spec(:,2:3)];

% Roll out one predicted animation per row of inp.
% For each row, the frame-to-frame changes of sprite id1's animation anim1
% are applied step by step, starting from the first frame of sprite id2's
% animation anim2. preds{a} receives the predicted frames and refs{a} the
% frames of the id1 animation.
%
% manifold == 1: the encoder is run on the first step only; afterwards the
%                pose code is taken from the decoder output of the
%                previous step.
% manifold == 0: the encoder is re-run on every step. Note that
%                batch_sprites.X3 is not updated inside the step loop
%                (that update is commented out below).
manifold = 1;
preds = cell(3,1);
refs = cell(3,1);
for a = 1:size(inp,1),
    id1 = inp(a,1);
    id2 = inp(a,2);
    anim1 = inp(a,3);
    anim2 = inp(a,4);

    % Each column of sprite{anim} is reshaped below into one
    % vdim x vdim x 3 frame; nf is the number of frames in the id1 animation.
    data1 = load(sprintf('%s/sprites_%d.mat', datadir, id1));
    sprite1 = data1.sprites;
    nf = size(sprite1{anim1},2);
    data2 = load(sprintf('%s/sprites_%d.mat', datadir, id2));
    sprite2 = data2.sprites;
    % X1/X2 hold consecutive frame pairs (k, k+1) of the id1 animation in
    % batch slots 1..nf-1. Slots beyond nf-1 are left untouched.
    batch_sprites.X1(:,:,:,1:(nf-1)) = ...
        reshape(sprite1{anim1}(:,1:(nf-1)), pars.vdim, pars.vdim, 3, nf-1);
    batch_sprites.X2(:,:,:,1:(nf-1)) = ...
        reshape(sprite1{anim1}(:,2:nf), pars.vdim, pars.vdim, 3, nf-1);
    % X3 slot 1 holds the first frame of the id2 animation (the query).
    batch_sprites.X3(:,:,:,1) = ...
        reshape(sprite2{anim2}(:,1), pars.vdim, pars.vdim, 3);
    % The predicted animation starts with the query frame itself.
    pred_anim = zeros(pars.vdim, pars.vdim, 3, nf);
    pred_anim(:,:,:,1) = batch_sprites.X3(:,:,:,1);

    for i = 1:(nf-1),
        fprintf(1,'inp %d step %d\n', a, i);
        if (i==1) || (manifold==0),
            Xin = { single(batch_sprites.X1); single(batch_sprites.X2); ...
                    single(batch_sprites.X3); single(batch_sprites.X4)};
            % Encoder.
            % NOTE(review): the last argument (0 here, 1 for the decoder
            % call below) presumably selects the network -- confirm against
            % the MatCaffe wrapper.
            result = caffe('forward', Xin, 0);
            out = squeeze(result{1});
            query = squeeze(result{2});
            ref = squeeze(result{3});
            %target = squeeze(result{4});

            % Per-slot difference of the first and third encoder outputs,
            % restricted to the rows listed in pars.pose_idx.
            trans = out - ref;
            trans = trans(pars.pose_idx,:);
            % Split the second encoder output into its id rows (passed
            % through apply_softmax) and its pose rows.
            top_id = query(pars.id_idx,:);
            top_id = apply_softmax(top_id, pars);
            top_pose = query(pars.pose_idx,:);

            % All-ones single array shaped like trans, passed as the fourth
            % decoder input. NOTE(review): its meaning is not visible here.
            sw = ones(size(trans), 'single');
        end
        % Copy step i's difference into column 1; only batch slot 1 of the
        % decoder's image output is kept below.
        trans(:,1) = trans(:,i);

        % Decoder.
        result = caffe('forward', ...
                       {trans, top_id, top_pose, sw }, 1);
        hid_new = squeeze(result{1});
        pred_mask = result{2};
        pred_sprite = result{3};

        % Element-wise product of the second and third decoder outputs,
        % with singleton dimensions expanded by bsxfun.
        pred = bsxfun(@times, pred_sprite, pred_mask);
        % Disabled debug visualisation.
        %{
        subplot(1,4,1);
        imagesc(single(batch_sprites.X1(:,:,:,i)));
        subplot(1,4,2);
        imagesc(single(batch_sprites.X2(:,:,:,i)));
        subplot(1,4,3);
        imagesc(single(batch_sprites.X3(:,:,:,i)));
        subplot(1,4,4);
        imagesc(single(pred(:,:,:,1)));
        %}
        pred_anim(:,:,:,i+1) = pred(:,:,:,1);

        % Carry the pose rows of the decoder's first output into the next
        % step; the id code is kept as computed by the encoder.
        %batch_sprites.X3 = pred;
        %top_id = hid_new(pars.id_idx,:);
        top_pose = hid_new(pars.pose_idx,:);
    end
    preds{a} = pred_anim;
    % Reference frames 1..nf-1 come from X1; the last frame (nf) is the
    % second element of the final pair, i.e. X2 slot nf-1.
    refs{a} = cat(4,batch_sprites.X1(:,:,:,1:nf-1),batch_sprites.X2(:,:,:,nf-1));
end
%{
%% Show references.
close all;
sep = 2;
vis = ones(3*60 + 2*sep, nf*60 + (nf-1)*sep, 3);
rowstart = 1;
for row = 1:3,
    rowend = rowstart + 60 - 1;
    colstart = 1;
    for col = 1:nf,
        colend = colstart + 60 - 1;        
        vis(rowstart:rowend,colstart:colend,:) = refs{row}(:,:,:,col);
        colstart = colend + sep + 1;
    end
    rowstart = rowend + sep + 1;
end
figure(1);
imagesc(vis); axis off;
%save('sprite_manifold_e0.mat', 'vis');
%save('sprite_manifold_e1.mat', 'vis');

%% Show predictions
sep = 2;
vis = ones(3*60 + 2*sep, nf*60 + (nf-1)*sep, 3);
rowstart = 1;
for row = 1:3,
    rowend = rowstart + 60 - 1;
    colstart = 1;
    for col = 1:nf,
        colend = colstart + 60 - 1;
        
        vis(rowstart:rowend,colstart:colend,:) = preds{row}(:,:,:,col);
        colstart = colend + sep + 1;
    end
    rowstart = rowend + sep + 1;
end
figure(2)
imagesc(vis); axis off;
%save('sprite_manifold_e0.mat', 'vis');
%save('sprite_manifold_e1.mat', 'vis');
%}
%%
%a=load('sprite_manifold_e0.mat');
%b=load('sprite_manifold_e1.mat');
%vis = a.vis;
%vis(end-60:end,:,:) = b.vis(end-60:end,:,:);
%imagesc(vis); axis off;
%print('-djpeg','sprite_manifold_clean');

%% Generate video pairs.
% Writes one reference clip and one generated clip per row of inp
% (walk, thrust, shoot) as AVI files in the current folder.
% Each clip is rendered frame by frame with imagesc and captured with
% getframe from the current axes.
close all;

fr = 10;   % frame rate assigned to every VideoWriter

% Output file name and the 4-D frame stack (frames along dimension 4)
% rendered into it. The write order matches the original script.
clips = { 'walking_ref.avi', refs{1}; ...
          'walking_gen.avi', preds{1}; ...
          'thrust_ref.avi',  refs{2}; ...
          'thrust_gen.avi',  preds{2}; ...
          'shoot_ref.avi',   refs{3}; ...
          'shoot_gen.avi',   preds{3} };

for c = 1:size(clips,1),
    frames = clips{c,2};
    writerObj = VideoWriter(clips{c,1});
    writerObj.FrameRate = fr;
    open(writerObj);
    try
        axis tight
        set(gca,'nextplot','replacechildren');
        set(gcf,'Renderer','zbuffer');
        % The loop bound comes from the stack being rendered, so a
        % generated clip never depends on the size of its reference clip.
        for k = 1:size(frames,4),
            % Rows are reversed before display.
            % NOTE(review): presumably compensates for the axes' y-direction
            % under 'replacechildren' -- confirm.
            imagesc(frames(end:-1:1,:,:,k));
            frame = getframe;
            writeVideo(writerObj,frame);
        end
    catch err
        % Do not leave the output file open if rendering or writing fails.
        close(writerObj);
        rethrow(err);
    end
    close(writerObj);
end
