
%% Load model.
% Put the local helper functions and the MATLAB Caffe wrapper on the path.
addpath('util');
addpath('../../matlab/caffe');
% Sprite frame size, presumably [height width channels] -- TODO confirm.
imgsize = [ 60 60 3 ];
% NOTE(review): numhid is not referenced anywhere in this script.
numhid = 256;
% Device selector passed to matcaffe_init_sprites_train.
% NOTE(review): -1 presumably means CPU mode and the commented-out 2 a GPU id -- confirm.
%use_gpu = 2;
use_gpu = -1;

% Saved training snapshot. Its `pars` struct supplies pars.batchsize,
% pars.id_idx and pars.pose_idx, which the sanity check reads.
init_from = 'results/sprites_manifold/analogy_sprite_manifold_a10.0000_b0.0100_m0_e1_20150601T015802.mat';
model = load(init_from);
pars = model.pars;
% Initialise Caffe from the encoder and decoder solver definitions.
% NOTE(review): presumably also restores weights from init_from -- confirm.
matcaffe_init_sprites_train(use_gpu, ...
                  'results/sprites_manifold/analogy_sprite_enc_solver.prototxt', ...
                  'results/sprites_manifold/analogy_sprite_dec_solver.prototxt', ...
                  init_from);

%% Load data.
% Directory holding one .mat file per sprite.
datadir = '/mnt/neocortex/scratch/reedscot/data/sprites';
% Which index set in sprites_splits.mat to evaluate on: 'testidx' (default)
% or 'trainidx'. The results go into testfiles/testlabels whichever
% split is selected, so downstream code does not change.
evalsplit = 'testidx';
splits = load('sprites_splits.mat');
files = dir([datadir '/*.mat']);
if isempty(files)
    error('No .mat files found in %s', datadir);
end
labels = load('sprite_large_labels.txt');
% The same index vector selects both directory entries and label rows.
% NOTE(review): this relies on dir() listing files in the same order that
% was used when the splits were built -- confirm.
evalidx = splits.(evalsplit);
testfiles = arrayfun(@(x) [ datadir '/' x.name ], files(evalidx), 'UniformOutput', false);
testlabels = labels(evalidx,:);

%% Sanity check.
% Build four input batches from two sprite files in the evaluation set:
% frames 5 and 6 of data1's animation 3, frame 5 of data2's animation 2,
% and frame 6 of data2's animation 3.
data1 = load(testfiles{5});
data2 = load(testfiles{9});
anim1 = data1.sprites{3};
% NOTE(review): anim1 and anim2 are the same animation, whereas anim3 and
% anim4 come from different animations (sprites{2} vs sprites{3}).
% Confirm that this asymmetry is the intended analogy.
anim2 = data1.sprites{3};
anim3 = data2.sprites{2};
anim4 = data2.sprites{3};

% Each column of an animation is one flattened frame. Reshape the chosen
% column to imgsize (plus a singleton 4th dim), convert it to single, and
% tile it pars.batchsize times along dim 4 to fill a whole batch.
to_batch = @(anim, frame) repmat(single(reshape(anim(:,frame), [imgsize 1])), ...
                                 [1 1 1 pars.batchsize]);
img1 = to_batch(anim1, 5);
img2 = to_batch(anim2, 6);
img3 = to_batch(anim3, 5);
img4 = to_batch(anim4, 6);

inp = { img1; img2; img3; img4 };
% First forward pass: net index 0 on the four image batches.
% NOTE(review): assumes the third argument selects which net set up by
% matcaffe_init_sprites_train runs (presumably 0 = encoder, 1 = decoder) -- confirm.
res = caffe('forward', inp, 0);
% The first three outputs are squeezed to drop singleton dimensions.
hid_out = squeeze(res{1});
hid_query = squeeze(res{2});
hid_ref = squeeze(res{3});

% Transformation code: difference between the first and third outputs.
trans = hid_out - hid_ref;
% Split the query code into the rows listed in pars.id_idx and pars.pose_idx.
top_id = hid_query(pars.id_idx,:);
top_pose = hid_query(pars.pose_idx,:);
% All-ones single array the same size as trans, passed as the 4th input.
% NOTE(review): its role is not visible here (presumably a switch/weight mask) -- confirm.
sw = ones(size(trans), 'single');

% Second forward pass: net index 1 on (trans, top_id, top_pose, sw).
result = caffe('forward', {trans, top_id, top_pose, sw }, 1);
% NOTE(review): hid_new and pred_mask are not used later in this script.
hid_new = squeeze(result{1});
pred_mask = result{2};
% Predicted sprite batch; its first item is plotted in the figure.
pred_sprite = result{3};

% Figure 1: the first item of each of the three input batches in panels
% 1-3, then the first predicted sprite in panel 4.
figure(1);
panels = { img1, img2, img3 };
for p = 1:numel(panels)
    subplot(1,4,p);
    imagesc(panels{p}(:,:,:,1));
end
% mytrim(..., 0, 1) presumably clamps the prediction to [0, 1] before display.
subplot(1,4,4);
imagesc(mytrim(pred_sprite(:,:,:,1),0,1));

