
%% Load model.
% Put the helper, model and MatCaffe directories on the MATLAB path.
% Each addpath call prepends, so the call order below is significant.
for extra_path = {'../util', '../model', '../../matlab/caffe'}
    addpath(extra_path{1});
end

% Basic configuration used by the rest of the script.
imgsize = [ 64 64 3 ];
numhid = 640;
use_gpu = 2;   % forwarded as-is to matcaffe_init_cars_train (meaning not visible here)

% Checkpoint to start from; it stores the training parameters in 'pars'.
subdir = 'cars_analogy';
init_from = sprintf('results/%s/analogy_cars_a0.01_m0_20151202T142653.mat', subdir);
model = load(init_from);
pars = model.pars;

% Set up the networks from the two solver definitions and the checkpoint.
% NOTE(review): "enc"/"dec" presumably mean encoder/decoder - confirm
% against matcaffe_init_cars_train.
enc_solver = sprintf('results/%s/analogy_enc_solver.prototxt', subdir);
dec_solver = sprintf('results/%s/analogy_dec_solver.prototxt', subdir);
matcaffe_init_cars_train(use_gpu, enc_solver, dec_solver, init_from);

%% Load data.
datadir = '../data/cars';

% Split by car id: 1-100 training, 101-149 validation, 150-199 testing.
% Some entries are dropped from each split by their position in the list.
excluded_train = [55, 76, 77, 82, 84, 99];
trainids = 1:100;
trainids(excluded_train) = [];

excluded_test = [31, 36, 48];
testids = 150:199;
testids(excluded_test) = [];

% Select which split this script runs on.
%ids = trainids;
ids = testids;

files = get_car_files(datadir, ids);

% All unordered pairs of positions in 'ids', with the two columns of each
% row swapped relative to what combnk returns.
testpairs = combnk(1:numel(ids), 2);
testpairs = [testpairs(:,2), testpairs(:,1)];

%% Compute N-step recon error.
% Fix the random stream so the sampled benchmark is reproducible.
rng('default');
rng(1);
nstep = 7;
% Allocated here but not filled anywhere in this section.
cost_nstep = zeros(size(testpairs,1)*pars.batchsize, nstep);

% One benchmark row per pair: [ id1, id2, ept1, ept2, ept3, offset ].
npairs = size(testpairs, 1);
test_benchmark = zeros(npairs, 6);
idx = 0;
for p = 1:npairs
    id1 = testpairs(p,1);
    id2 = testpairs(p,2);

    % Disabled check that skipped pairs whose files cannot be loaded:
    %   try
    %       data1 = load(files{id1});
    %       data2 = load(files{id2});
    %   catch
    %       fprintf(1,'Could not load (%d,%d)\n', id1, id2);
    %       continue;
    %   end

    % Draw a start index in 1..24 and a step in -2..2; the second index
    % wraps around so that it also stays within 1..24.
    ept1 = randsample(24, 1);
    offset = randsample(-2:2, 1);
    ept2 = mod(ept1 + offset - 1, 24) + 1;

    % Independent index in 1..24 for the second item of the pair.
    ept3 = randsample(24, 1);

    idx = idx + 1;
    test_benchmark(idx,:) = [ id1, id2, ept1, ept2, ept3, offset ];
end
% Keep only the rows that were actually written.
test_benchmark = test_benchmark(1:idx,:);
%save('cars_test_benchmark.mat','test_benchmark');

%load('cars_test_benchmark.mat');

%% Generate videos.
% Hand-picked examples. Each row of 'ids' holds two positions in 'files':
% the first supplies the ref/out images, the second supplies the query.
ids = [ 22, 13; ...
        42, 26; ...
        19, 32; ...
         8, 39; ...
        14, 22; ...
         5,  8 ];

% Indices into the 4th (ept*) and 5th (apt*) dimensions of each file's
% 'im' array, one entry per example, for ref (1), out (2) and query (3).
% NOTE(review): what these two dimensions represent is not visible in
% this file - confirm against the data files.
ept1 = [ 2, 4, 24, 3, 2, 24 ];
ept2 = [ 1, 3, 23, 2, 1, 23 ];
ept3 = ones(1, 6);
apt1 = 3 * ones(1, 6);
apt2 = 3 * ones(1, 6);
apt3 = 3 * ones(1, 6);

sz = [ 24, 4 ];
bsz = [ imgsize, 3, pars.batchsize ];

ns = length(apt1);

% Pick one view out of an image stack, scale it to [0,1] and resize it to
% 64x64.
to_input = @(im, e, a) imresize(single(im(:,:,:,e,a))/255., [64 64]);

% Batches are allocated at the full batch size in single precision; only
% the first ns slots are filled, the remainder stay zero.
ref = zeros(64, 64, 3, pars.batchsize, 'single');
out = zeros(64, 64, 3, pars.batchsize, 'single');
query = zeros(64, 64, 3, pars.batchsize, 'single');
for n = 1:ns
    id1 = ids(n,1);
    id2 = ids(n,2);
    data1 = load(files{id1});
    data2 = load(files{id2});

    ref(:,:,:,n) = to_input(data1.im, ept1(n), apt1(n));
    out(:,:,:,n) = to_input(data1.im, ept2(n), apt2(n));
    query(:,:,:,n) = to_input(data2.im, ept3(n), apt3(n));
end

% Encode the image batches with network 0.
% NOTE(review): the outputs are read as (out, query, ref) in that order,
% based on how they are used below - confirm against the network
% definition.
res = caffe('forward', { ref; out; query; ref }, 0);
hid_out = squeeze(res{1});
hid_query = squeeze(res{2});
hid_ref = squeeze(res{3});

% Difference between the 'out' and 'ref' codes: over all units (trans,
% kept for inspection only) and restricted to the pose units (pose_trans).
trans = hid_out - hid_ref;
pose_trans = hid_out(pars.pose_idx,:) - hid_ref(pars.pose_idx,:);
top_id = hid_query(pars.id_idx,:);
top_pose = hid_query(pars.pose_idx,:);

% One switch value per pose unit and batch element. The original called
% numel with a stray ':' argument; the plain element count is intended.
sw = ones(numel(pars.pose_idx), pars.batchsize, 'single');

% Broadcast a (units x batch) matrix over the image plane so that it can
% be concatenated with the query image along the channel dimension.
imh = size(query, 1);
imw = size(query, 2);
tile_pose = @(t) repmat(permute(t, [4,3,1,2]), [imh imw 1 1]);

% First prediction: zero transformation and zero switch on network 1.
%res1 = caffe('forward', { 0*pose_trans, top_id, top_pose, 0*sw }, 1);
inpq = cat(3, query, tile_pose(0*pose_trans));
res1 = caffe('forward', { 0*pose_trans, top_id, top_pose, 0*sw, inpq }, 1);
pred1 = res1{3};

nstep = 5;
% Frame slots along dim 5: 1 = ref, 2 = out, 3 = query, 4 = first
% prediction, 5..nstep+3 = steps along +pose_trans, nstep+4..2*nstep+2 =
% steps along -pose_trans.
pred_nstep = zeros([imh, imw, size(query,3), ns, 2*nstep+2]);
pred_nstep(:,:,:,:,1) = ref(:,:,:,1:ns);
pred_nstep(:,:,:,:,2) = out(:,:,:,1:ns);
pred_nstep(:,:,:,:,3) = query(:,:,:,1:ns);
pred_nstep(:,:,:,:,4) = pred1(:,:,:,1:ns);

% Roll the transformation out in both directions. Every direction starts
% again from the encoded query; within a direction each prediction and
% its pose code are fed back as the next input.
top_pose_save = top_pose;
query_save = query;
step_sign = [ 1, -1 ];
slot_base = [ 3, nstep+2 ];   % frame n of a direction goes to slot_base + n
for d = 1:numel(step_sign)
    step_trans = step_sign(d) * pose_trans;
    top_pose = top_pose_save;
    query = query_save;
    for n = 2:nstep
        fprintf(1,'frame %d of %d\n', n, nstep);
        inpq = cat(3, query, tile_pose(step_trans));
        res = caffe('forward', { step_trans, top_id, top_pose, sw, inpq }, 1);
        top = squeeze(res{1});
        top_pose = top(pars.pose_idx,:);
        pred = res{3};
        pred_nstep(:,:,:,:,slot_base(d)+n) = pred(:,:,:,1:ns);

        % update query to get new inpq.
        query = pred;
    end
end

% Reorder into display order: ref, out, the -pose_trans frames from last
% to first, then the first prediction followed by the +pose_trans frames.
% The raw query frame (slot 3) is left out.
pred_nstep = cat(5, pred_nstep(:,:,:,:,1:2), ...
                 pred_nstep(:,:,:,:,end:-1:nstep+4), ...
                 pred_nstep(:,:,:,:,4:nstep+3));
%
% Tile the frames into one image: one row of tiles per example, one
% column of tiles per frame.
[nrow, ncol, nchan, nex, nframe] = size(pred_nstep);
vis = permute(pred_nstep, [ 1 4 2 5 3 ]);
vis = reshape(vis, [nrow*nex, ncol*nframe, nchan]);
%
% Use a dedicated figure so that the cropped mosaic drawn below does not
% replace this one.
figure;
imagesc(vis); axis off;
%savefig('cars_rot_easy_test','jpeg');

% vcrop
% Keep only image rows 17..48 of every frame and tile them the same way.
crop_rows = 17:48;
vis_vcrop = pred_nstep(crop_rows,:,:,:,:);
vis = permute(vis_vcrop, [ 1 4 2 5 3 ]);
vis = reshape(vis, [numel(crop_rows)*nex, ncol*nframe, nchan]);
%
figure;
imagesc(vis); axis off;
%%
% NOTE: this uses the same file name as the savefig line above; choose a
% different name if both are enabled.
%savefig('cars_rot_easy_test','jpeg');

