function eval_predictions(use_gpu, expname)
%EVAL_PREDICTIONS Evaluate repeated analogy predictions on test shape pairs.
%   EVAL_PREDICTIONS(USE_GPU, EXPNAME) loads the trained model
%   results/shapes_<EXPNAME>/analogy_shapes_<EXPNAME>.mat (EXPNAME is one of
%   'add', 'mul', 'deep'; default 'deep'), initialises Caffe through
%   matcaffe_init_shapes_train(USE_GPU, ...), and for every test pair of
%   shapes scores the output of analogy_repeated over a number of steps for
%   rotation, scaling and translation (translation cost is the average of
%   the x- and y-translation costs). The cost for each step is the squared
%   error summed over the image dimensions, divided by the number of steps.
%
%   Test pairs are the (id1,id2) index pairs where pars.pairs is zero, minus
%   30 pairs drawn at random (rng(1)) as a validation split.
%
%   Results (cost_nstep_rot, cost_nstep_scale, cost_nstep_trans,
%   test_benchmark) are saved to shapes_rot_exp_hard_<EXPNAME>_<timestamp>.mat
%   and the mean cost per step is printed and plotted in figure 1.

addpath('../util');
addpath('../../matlab/caffe');

% Default arguments.
if ~exist('use_gpu', 'var')
    % NOTE(review): -1 is the default the author wrote; confirm how
    % matcaffe_init_shapes_train interprets it (CPU vs. device id).
    use_gpu = -1;
end
if ~exist('expname', 'var')
    expname = 'deep';
end

known_exps = { 'add', 'mul', 'deep' };
if ~ischar(expname) || ~any(strcmp(expname, known_exps))
    error('eval_predictions:unknownExpname', ...
          'unknown expname (expected ''add'', ''mul'' or ''deep'')');
end
subdir = sprintf('shapes_%s', expname);
init_from = sprintf('results/%s/analogy_%s.mat', subdir, subdir);

model = load(init_from);
pars = model.pars;
matcaffe_init_shapes_train(use_gpu, ...
                  sprintf('results/%s/analogy_enc_solver.prototxt', subdir), ...
                  sprintf('results/%s/analogy_dec_solver.prototxt', subdir), ...
                  init_from);
caffe('set_phase_test');

% Load only the image array M instead of dumping the file into the workspace.
shapes = load('../data/shapes48.mat', 'M');
data = single(shapes.M);

%% Validation / test split over the pairs not set in pars.pairs.
rng('default');
rng(1);
[ pair_i, pair_j ] = find(~pars.pairs);
valtestpairs = [ pair_i, pair_j ];
Rval = randsample(size(valtestpairs,1), 30);   % validation rows (not used here)
Rtest = setdiff(1:size(valtestpairs,1), Rval);
testpairs = valtestpairs(Rtest,:);

%% Evaluate every test pair.
rng('default');
rng(1);
nstep_rot = 12;
nstep_scale = 5;
nstep_trans = 5;
ntest = size(testpairs,1);
cost_nstep_rot = zeros(ntest, nstep_rot);
cost_nstep_scale = zeros(ntest, nstep_scale);
cost_nstep_trans = zeros(ntest, nstep_trans);
% Predicted steps 2..numstep_show+1 are shown, so keep it within every nstep.
numstep_show = min(4, min([ nstep_rot, nstep_scale, nstep_trans ]) - 1);
test_benchmark = [];
for p = 1:ntest
    id1 = testpairs(p,1);
    id2 = testpairs(p,2);

    % Transformation type codes passed to sample_analogy_shapes, as used in
    % this script: 1 = rotation, 2 = scaling, 3 = x-pos, 4 = y-pos.
    [ ref_rot, out_rot, query_rot, pred_nstep_rot, cost_rot, labels_rot ] = ...
        eval_transform(data, id1, id2, 1, pars, expname, nstep_rot);
    cost_nstep_rot(p,:) = cost_rot;

    [ ref_scale, out_scale, query_scale, pred_nstep_scale, cost_scale, labels_scale ] = ...
        eval_transform(data, id1, id2, 2, pars, expname, nstep_scale);
    cost_nstep_scale(p,:) = cost_scale;

    % Translation cost averages x- and y-translation; the y-translation
    % batch is the one kept for visualization.
    [ ~, ~, ~, ~, cost_xpos, labels_xpos ] = ...
        eval_transform(data, id1, id2, 3, pars, expname, nstep_trans);
    [ ref_trans, out_trans, query_trans, pred_nstep_trans, cost_ypos, labels_ypos ] = ...
        eval_transform(data, id1, id2, 4, pars, expname, nstep_trans);
    cost_nstep_trans(p,:) = cost_xpos / 2 + cost_ypos / 2;

    %% Store benchmark info.
    benchmark_info = struct('id1', id1, ...
                            'id2', id2, ...
                            'rot', labels_rot, ...
                            'scale', labels_scale, ...
                            'xpos', labels_xpos, ...
                            'ypos', labels_ypos);
    test_benchmark = [ test_benchmark, benchmark_info ]; %#ok<AGROW>

    %% Visualization sanity check.
    if mod(p,10) == 0
        fprintf(1,'test batch %d of %d\n', p, ntest);
        % One row per transformation: ref, out, query, predicted steps.
        vis = double(cat(4, ...
            make_vis_row(ref_rot, out_rot, query_rot, pred_nstep_rot, numstep_show), ...
            make_vis_row(ref_scale, out_scale, query_scale, pred_nstep_scale, numstep_show), ...
            make_vis_row(ref_trans, out_trans, query_trans, pred_nstep_trans, numstep_show)));
        [ h, w, c, numex, ncol ] = size(vis);
        % Tile into a (h*numex)-by-(w*ncol) image grid.
        vis = permute(vis, [ 1 4 2 5 3 ]);
        vis = reshape(vis, [ h*numex, w*ncol, c ]);
        imagesc(vis); axis off;
        drawnow;
    end
end

%% Save results.
fname_save = sprintf('shapes_rot_exp_hard_%s_%s.mat', ...
                     expname, datestr(now,30));
save(fname_save, 'cost_nstep_rot', 'cost_nstep_scale', ...
     'cost_nstep_trans', 'test_benchmark');

%% Print and plot mean cost per step.
summary_names = { 'Rotation', 'Scaling', 'Translation' };
summary_costs = { cost_nstep_rot, cost_nstep_scale, cost_nstep_trans };
figure(1);
for k = 1:numel(summary_names)
    mean_cost = mean(summary_costs{k}, 1);
    fprintf(1,'%s error:\n', summary_names{k});
    disp(mean_cost);
    subplot(1,3,k);
    scatter(1:numel(mean_cost), mean_cost);
    title(sprintf('%s error', summary_names{k}));
end


function [ ref, out, query, pred, cost, labels ] = ...
    eval_transform(data, id1, id2, ttype, pars, expname, nstep)
%EVAL_TRANSFORM Sample one analogy batch of type TTYPE and score NSTEP predictions.
%   Returns the sampled inputs (batch fields X1, X2, X3), the output of
%   analogy_repeated, the per-step cost as a 1-by-NSTEP row (squared error
%   summed over dims 1-3, divided by NSTEP) and the labels returned by
%   sample_analogy_shapes.
[ batch_data, labels ] = sample_analogy_shapes(data, id1, id2, ttype, pars);
ref = batch_data.X1;
out = batch_data.X2;
query = batch_data.X3;
target = batch_data.X4;
res = caffe('forward', { ref; out; query; target }, 0);
% NOTE(review): outputs are read as res{1}=out, res{2}=query, res{3}=ref,
% which differs from the input order; confirm against the encoder prototxt.
hid_out = res{1};
hid_query = res{2};
hid_ref = res{3};
pred = analogy_repeated(expname, pars, hid_ref, hid_out, hid_query, nstep);
errsq = (pred - target(:,:,:,1:nstep)).^2;
% Explicit dims so a singleton image dimension cannot shift the summed axis.
cost = reshape(sum(sum(sum(errsq, 1), 2), 3), 1, []) / nstep;


function row = make_vis_row(ref, out, query, pred, numstep_show)
%MAKE_VIS_ROW Stack the first ref/out/query image and predictions
%   2..NUMSTEP_SHOW+1 into an [h w c 1 3+NUMSTEP_SHOW] array.
frames = cat(4, ref(:,:,:,1), out(:,:,:,1), query(:,:,:,1), ...
             pred(:,:,:,2:numstep_show+1));
row = reshape(frames, [ size(ref,1), size(ref,2), size(ref,3), 1, 3+numstep_show ]);