
%% Load model.
% Put the shared utilities and the Caffe MATLAB wrapper on the search path.
addpath('../util');
addpath('../../matlab/caffe');

% Script-level settings. Only use_gpu is consumed in this section.
use_gpu = 2;
numhid = 1024;
imgsize = [64, 64, 3];

% Every artifact of this experiment lives under results/<subdir>/.
subdir = 'cars_dis';
result_prefix = ['results/', subdir, '/'];
init_from = [result_prefix, 'analogy_cars_dis_a0.01_m0_20151202T230209.mat'];
enc_solver = [result_prefix, 'analogy_enc_solver.prototxt'];
dec_solver = [result_prefix, 'analogy_dec_solver.prototxt'];

% Initialise Caffe from the two solver definitions and the saved snapshot ...
matcaffe_init_cars_train(use_gpu, enc_solver, dec_solver, init_from);

% ... then read the parameter struct stored in the same snapshot file.
model = load(init_from);
pars = model.pars;

%% Load data.
% Build the train / val / test file lists. Entries are dropped by their
% POSITION within each id range (for valids, position k is id 100+k).
datadir = '../data/cars';

trainids = 1:100;
trainids = trainids(setdiff(1:numel(trainids), [55, 76, 77, 82, 84, 99]));

valids = 101:149;
valids = valids(setdiff(1:numel(valids), [5, 6, 17, 29, 34, 40, 43, 47]));

testids = 150:199;

trainfiles = get_car_files(datadir, trainids);
valfiles = get_car_files(datadir, valids);
testfiles = get_car_files(datadir, testids);

% For the test split the unwanted entries are removed from the file list
% itself rather than from the id range.
badfiles = [31, 36, 48];
testfiles(badfiles) = [];

% Enumerate every unordered pair of file indices within each split.
num_train = numel(trainfiles);
num_val = numel(valfiles);
num_test = numel(testfiles);
trainpairs = combnk(1:num_train, 2);
valpairs = combnk(1:num_val, 2);
testpairs = combnk(1:num_test, 2);

%% Precompute ID and pose verification pairs.
% Each row of *_vrf_pairs describes one verification example:
%   [ id1, id2, ept1, apt1, ept2, apt2, ept3, apt3 ]
% where id1/id2 index into the split's file list and (eptK, aptK) is the
% K-th randomly drawn pose, used later as data.im(:,:,:,eptK,aptK).
% ept* is drawn from 1..24 and apt* from 1..4.
%
% Fix: column 4 used to be filled with apt2, so the sampled apt1 was thrown
% away and the reference pose always shared its second index with pose 2.
% The order of the random draws is unchanged, so every other column is
% identical to what the previous version produced for the same seed.
%
% NOTE(review): the three poses are drawn independently, so pose 2 can
% coincide with pose 1 for some rows; confirm whether such rows should be
% re-drawn, since they are later labelled as "different pose".
rng('default');
rng(1);
train_vrf_pairs = zeros(size(trainpairs,1),8);
val_vrf_pairs = zeros(size(valpairs,1),8);
test_vrf_pairs = zeros(size(testpairs,1),8);

disp('train');
for p = 1:size(trainpairs,1)
    id1 = trainpairs(p,1);
    id2 = trainpairs(p,2);
    ept1 = randsample(24, 1);
    apt1 = randsample(4, 1);
    ept2 = randsample(24, 1);
    apt2 = randsample(4, 1);
    ept3 = randsample(24, 1);
    apt3 = randsample(4, 1);
    train_vrf_pairs(p,:) = [ id1, id2, ept1, apt1, ept2, apt2, ept3, apt3 ];
end

disp('val');
for p = 1:size(valpairs,1)
    id1 = valpairs(p,1);
    id2 = valpairs(p,2);
    ept1 = randsample(24, 1);
    apt1 = randsample(4, 1);
    ept2 = randsample(24, 1);
    apt2 = randsample(4, 1);
    ept3 = randsample(24, 1);
    apt3 = randsample(4, 1);
    val_vrf_pairs(p,:) = [ id1, id2, ept1, apt1, ept2, apt2, ept3, apt3 ];
end

disp('test');
for p = 1:size(testpairs,1)
    id1 = testpairs(p,1);
    id2 = testpairs(p,2);
    ept1 = randsample(24, 1);
    apt1 = randsample(4, 1);
    ept2 = randsample(24, 1);
    apt2 = randsample(4, 1);
    ept3 = randsample(24, 1);
    apt3 = randsample(4, 1);
    test_vrf_pairs(p,:) = [ id1, id2, ept1, apt1, ept2, apt2, ept3, apt3 ];
end

% Persist the pairs, then reload them so the rest of the script works from
% exactly what is stored on disk.
save('cars_vrf_pairs.mat', 'train_vrf_pairs', 'val_vrf_pairs', 'test_vrf_pairs');
load('cars_vrf_pairs.mat');
%% Extract all features
% One single-precision column of pars.numhid features per verification pair.
% For pose AUC: pose_ref_tr, pose_diff_tr, pose_same_tr,
%               pose_ref_val, pose_diff_val, pose_same_val
%               pose_ref_ts, pose_diff_ts, pose_same_ts
% For iden AUC: iden_ref_tr, iden_diff_tr, iden_same_tr,
%               iden_ref_val, iden_diff_val, iden_same_val
%               iden_ref_ts, iden_diff_ts, iden_same_ts
n_tr = size(train_vrf_pairs, 1);
n_val = size(val_vrf_pairs, 1);
n_ts = size(test_vrf_pairs, 1);

% Pose-verification features.
[pose_ref_tr, pose_diff_tr, pose_same_tr] = deal(zeros(pars.numhid, n_tr, 'single'));
[pose_ref_val, pose_diff_val, pose_same_val] = deal(zeros(pars.numhid, n_val, 'single'));
[pose_ref_ts, pose_diff_ts, pose_same_ts] = deal(zeros(pars.numhid, n_ts, 'single'));

% Identity-verification features.
[iden_ref_tr, iden_diff_tr, iden_same_tr] = deal(zeros(pars.numhid, n_tr, 'single'));
[iden_ref_val, iden_diff_val, iden_same_val] = deal(zeros(pars.numhid, n_val, 'single'));
[iden_ref_ts, iden_diff_ts, iden_same_ts] = deal(zeros(pars.numhid, n_ts, 'single'));

% Train split: run the five images of every verification pair through the
% network (mode 0) and keep the first column of each returned feature matrix.
fetch_view = @(d, e, a) imresize(single(d.im(:,:,:,e,a))/255., [64 64]);
tile_batch = @(x) repmat(x, [1 1 1 pars.batchsize]);
n_train_pairs = size(train_vrf_pairs,1);
for p = 1:n_train_pairs
    fprintf(1,'extracting train features %d of %d\n', p, n_train_pairs);

    row = train_vrf_pairs(p,:);
    id1 = row(1);
    id2 = row(2);
    ept1 = row(3);  apt1 = row(4);
    ept2 = row(5);  apt2 = row(6);
    ept3 = row(7);  apt3 = row(8);

    data1 = load(trainfiles{id1});
    data2 = load(trainfiles{id2});

    % ref:       car id1 at pose 1
    % pose_same: car id2 at pose 1      pose_diff: car id2 at pose 2
    % iden_same: car id1 at pose 2      iden_diff: car id2 at pose 3
    % Each image is replicated along dim 4 to fill a whole batch.
    ref = tile_batch(fetch_view(data1, ept1, apt1));
    pose_same = tile_batch(fetch_view(data2, ept1, apt1));
    pose_diff = tile_batch(fetch_view(data2, ept2, apt2));
    iden_same = tile_batch(fetch_view(data1, ept2, apt2));
    iden_diff = tile_batch(fetch_view(data2, ept3, apt3));

    % Pose comparison images. The fourth input is an all-zero placeholder.
    res = caffe('forward', { ref; pose_same; pose_diff; 0*ref }, 0);
    hid_pose_same = squeeze(res{1});
    hid_pose_diff = squeeze(res{2});

    % Identity comparison images; the reference features are taken from
    % the third output of this second pass.
    res = caffe('forward', { ref; iden_same; iden_diff; 0*ref }, 0);
    hid_iden_same = squeeze(res{1});
    hid_iden_diff = squeeze(res{2});
    hid_ref = squeeze(res{3});

    % All batch columns come from the same replicated image; keep column 1.
    ref_col = hid_ref(:,1);
    pose_ref_tr(:,p) = ref_col;
    iden_ref_tr(:,p) = ref_col;
    pose_same_tr(:,p) = hid_pose_same(:,1);
    pose_diff_tr(:,p) = hid_pose_diff(:,1);
    iden_same_tr(:,p) = hid_iden_same(:,1);
    iden_diff_tr(:,p) = hid_iden_diff(:,1);
end

% Validation split: run the five images of every verification pair through
% the network (mode 0) and keep the first column of each feature matrix.
for p = 1:size(val_vrf_pairs,1)
    fprintf(1,'extracting val features %d of %d\n', ...
            p, size(val_vrf_pairs,1));
    id1 = val_vrf_pairs(p,1);
    id2 = val_vrf_pairs(p,2);
    data1 = load(valfiles{id1});
    data2 = load(valfiles{id2});

    ept1 = val_vrf_pairs(p,3);
    apt1 = val_vrf_pairs(p,4);
    ept2 = val_vrf_pairs(p,5);
    apt2 = val_vrf_pairs(p,6);
    % Fix: the third pose was previously read from train_vrf_pairs, so the
    % validation iden_diff image used a pose sampled for an unrelated
    % training pair instead of the one stored for this validation pair.
    ept3 = val_vrf_pairs(p,7);
    apt3 = val_vrf_pairs(p,8);

    % ref:       car id1 at pose 1
    % pose_same: car id2 at pose 1      pose_diff: car id2 at pose 2
    % iden_same: car id1 at pose 2      iden_diff: car id2 at pose 3
    ref = imresize(single(data1.im(:,:,:,ept1,apt1))/255., [64 64]);
    pose_same = imresize(single(data2.im(:,:,:,ept1,apt1))/255., [64 64]);
    pose_diff = imresize(single(data2.im(:,:,:,ept2,apt2))/255., [64 64]);
    iden_same = imresize(single(data1.im(:,:,:,ept2,apt2))/255., [64 64]);
    iden_diff = imresize(single(data2.im(:,:,:,ept3,apt3))/255., [64 64]);

    % Replicate each image along dim 4 to fill a whole batch.
    ref = repmat(ref,[1 1 1 pars.batchsize]);
    pose_same = repmat(pose_same,[1 1 1 pars.batchsize]);
    pose_diff = repmat(pose_diff,[1 1 1 pars.batchsize]);
    iden_same = repmat(iden_same,[1 1 1 pars.batchsize]);
    iden_diff = repmat(iden_diff,[1 1 1 pars.batchsize]);

    % Pose comparison images. The fourth input is an all-zero placeholder.
    res = caffe('forward', { ref; pose_same; pose_diff; 0*ref }, 0);
    hid_pose_same = squeeze(res{1});
    hid_pose_diff = squeeze(res{2});

    % Identity comparison images; the reference features are taken from
    % the third output of this second pass.
    res = caffe('forward', { ref; iden_same; iden_diff; 0*ref }, 0);
    hid_iden_same = squeeze(res{1});
    hid_iden_diff = squeeze(res{2});
    hid_ref = squeeze(res{3});

    % All batch columns come from the same replicated image; keep column 1.
    pose_ref_val(:,p) = hid_ref(:,1);
    pose_same_val(:,p) = hid_pose_same(:,1);
    pose_diff_val(:,p) = hid_pose_diff(:,1);
    iden_ref_val(:,p) = hid_ref(:,1);
    iden_same_val(:,p) = hid_iden_same(:,1);
    iden_diff_val(:,p) = hid_iden_diff(:,1);
end

% Test split: run the five images of every verification pair through the
% network (mode 0) and keep the first column of each feature matrix.
for p = 1:size(test_vrf_pairs,1)
    fprintf(1,'extracting test features %d of %d\n', ...
            p, size(test_vrf_pairs,1));
    id1 = test_vrf_pairs(p,1);
    id2 = test_vrf_pairs(p,2);
    data1 = load(testfiles{id1});
    data2 = load(testfiles{id2});

    ept1 = test_vrf_pairs(p,3);
    apt1 = test_vrf_pairs(p,4);
    ept2 = test_vrf_pairs(p,5);
    apt2 = test_vrf_pairs(p,6);
    % Fix: the third pose was previously read from train_vrf_pairs, so the
    % test iden_diff image used a pose sampled for an unrelated training
    % pair instead of the one stored for this test pair.
    ept3 = test_vrf_pairs(p,7);
    apt3 = test_vrf_pairs(p,8);

    % ref:       car id1 at pose 1
    % pose_same: car id2 at pose 1      pose_diff: car id2 at pose 2
    % iden_same: car id1 at pose 2      iden_diff: car id2 at pose 3
    ref = imresize(single(data1.im(:,:,:,ept1,apt1))/255., [64 64]);
    pose_same = imresize(single(data2.im(:,:,:,ept1,apt1))/255., [64 64]);
    pose_diff = imresize(single(data2.im(:,:,:,ept2,apt2))/255., [64 64]);
    iden_same = imresize(single(data1.im(:,:,:,ept2,apt2))/255., [64 64]);
    iden_diff = imresize(single(data2.im(:,:,:,ept3,apt3))/255., [64 64]);

    % Replicate each image along dim 4 to fill a whole batch.
    ref = repmat(ref,[1 1 1 pars.batchsize]);
    pose_same = repmat(pose_same,[1 1 1 pars.batchsize]);
    pose_diff = repmat(pose_diff,[1 1 1 pars.batchsize]);
    iden_same = repmat(iden_same,[1 1 1 pars.batchsize]);
    iden_diff = repmat(iden_diff,[1 1 1 pars.batchsize]);

    % Pose comparison images. The fourth input is an all-zero placeholder.
    res = caffe('forward', { ref; pose_same; pose_diff; 0*ref }, 0);
    hid_pose_same = squeeze(res{1});
    hid_pose_diff = squeeze(res{2});

    % Identity comparison images; the reference features are taken from
    % the third output of this second pass.
    res = caffe('forward', { ref; iden_same; iden_diff; 0*ref }, 0);
    hid_iden_same = squeeze(res{1});
    hid_iden_diff = squeeze(res{2});
    hid_ref = squeeze(res{3});

    % All batch columns come from the same replicated image; keep column 1.
    pose_ref_ts(:,p) = hid_ref(:,1);
    pose_same_ts(:,p) = hid_pose_same(:,1);
    pose_diff_ts(:,p) = hid_pose_diff(:,1);
    iden_ref_ts(:,p) = hid_ref(:,1);
    iden_same_ts(:,p) = hid_iden_same(:,1);
    iden_diff_ts(:,p) = hid_iden_diff(:,1);
end

%% Compute pose and ID AUC.
%addpath('/mnt/neocortex/library/liblinear-1.8/matlab');
addpath('/mnt/neocortex2/scratch/yeezhang/liblinear/matlab');
% Candidate C values handed to eval_cls (presumably the liblinear cost
% parameter -- confirm against eval_cls).
Clist = [ 0.0001, 0.001, 0.01, 0.1, 1 ];

%% Pose -> Pose.
% Pose verification set: each example is |reference - other| per feature.
% Label 0 = other image has the same pose, label 1 = a different pose.
n_tr = size(pose_ref_tr, 2);
n_val = size(pose_ref_val, 2);
n_ts = size(pose_ref_ts, 2);
Ytrain = [ zeros(n_tr, 1); ones(n_tr, 1) ];
Yval = [ zeros(n_val, 1); ones(n_val, 1) ];
Ytest = [ zeros(n_ts, 1); ones(n_ts, 1) ];
Ftrain = abs([ pose_ref_tr - pose_same_tr, pose_ref_tr - pose_diff_tr ]);
Fval = abs([ pose_ref_val - pose_same_val, pose_ref_val - pose_diff_val ]);
Ftest = abs([ pose_ref_ts - pose_same_ts, pose_ref_ts - pose_diff_ts ]);

% Classify using only the feature rows listed in pars.pose_idx.
feat_rows = pars.pose_idx;
Xtr = abs(Ftrain(feat_rows,:));
Xval = abs(Fval(feat_rows,:));
Xts = abs(Ftest(feat_rows,:));
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'Pose -> Pose: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set: labels are remapped from {0,1} to {-1,+1} and
% the model's first label is overwritten with -1 before calling svm_plotroc.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'Pose -> Pose SVM auc: %.4g\n', auc);

%% ID -> Pose
% Same pose-verification examples and labels as above, but classified from
% the feature rows listed in pars.id_idx.
feat_rows = pars.id_idx;
Xtr = abs(Ftrain(feat_rows,:));
Xval = abs(Fval(feat_rows,:));
Xts = abs(Ftest(feat_rows,:));
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'Id -> Pose: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set with labels remapped from {0,1} to {-1,+1}.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'Id -> Pose SVM auc: %.4g\n', auc);

%% All -> Pose
% Same pose-verification examples and labels, classified from every
% feature row.
Xtr = abs(Ftrain);
Xval = abs(Fval);
Xts = abs(Ftest);
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'All -> Pose: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set with labels remapped from {0,1} to {-1,+1}.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'All -> Pose SVM auc: %.4g\n', auc);

%% ID -> ID
% Identity verification set: each example is |reference - other| per
% feature. Label 0 = other image shows the same car (id1), label 1 = the
% other car (id2).
n_tr = size(iden_ref_tr, 2);
n_val = size(iden_ref_val, 2);
n_ts = size(iden_ref_ts, 2);
Ytrain = [ zeros(n_tr, 1); ones(n_tr, 1) ];
Yval = [ zeros(n_val, 1); ones(n_val, 1) ];
Ytest = [ zeros(n_ts, 1); ones(n_ts, 1) ];
Ftrain = abs([ iden_ref_tr - iden_same_tr, iden_ref_tr - iden_diff_tr ]);
Fval = abs([ iden_ref_val - iden_same_val, iden_ref_val - iden_diff_val ]);
Ftest = abs([ iden_ref_ts - iden_same_ts, iden_ref_ts - iden_diff_ts ]);

% Classify using only the feature rows listed in pars.id_idx.
feat_rows = pars.id_idx;
Xtr = abs(Ftrain(feat_rows,:));
Xval = abs(Fval(feat_rows,:));
Xts = abs(Ftest(feat_rows,:));
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'ID -> ID: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set with labels remapped from {0,1} to {-1,+1}.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'ID -> ID SVM auc: %.4g\n', auc);

%% Pose -> ID
% Same identity-verification examples and labels as above, but classified
% from the feature rows listed in pars.pose_idx.
feat_rows = pars.pose_idx;
Xtr = abs(Ftrain(feat_rows,:));
Xval = abs(Fval(feat_rows,:));
Xts = abs(Ftest(feat_rows,:));
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'Pose -> ID: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set with labels remapped from {0,1} to {-1,+1}.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'Pose -> ID SVM auc: %.4g\n', auc);

%% All -> Id
% Same identity-verification examples and labels, classified from every
% feature row.
Xtr = abs(Ftrain);
Xval = abs(Fval);
Xts = abs(Ftest);
[acc_val, acc_test, best_model] = eval_cls(Clist, Xtr, Ytrain, Xval, Yval, Xts, Ytest);
fprintf(1, 'All -> ID: acc_val: %.4g, acc_test: %.4g\n', acc_val, acc_test);

% ROC / AUC on the test set with labels remapped from {0,1} to {-1,+1}.
tmp = double(Ytest);
tmp(tmp == 0) = -1;
inst = sparse(double(Xts));
best_model.Label(1) = -1;
auc = svm_plotroc(tmp, inst, best_model);
fprintf(1, 'All -> ID SVM auc: %.4g\n', auc);

%% Test set pixel prediction error.
% For every test verification pair, build a target feature vector, decode
% it back to an image, and accumulate the Euclidean error in feature space
% (cost_hid) and in pixel space (cost_img). Both are averaged over pairs.
%
% expname selects how the target features are assembled:
%   'dis' - take car id1's features at pose 1 and overwrite the rows in
%           pars.pose_idx with those of car id2 at pose 2; the target is
%           car id1 at pose 2.
%   'mul' - analogy: the feature difference between car id2's two poses is
%           passed as trans together with car id1's pose-1 features; the
%           target is car id1 at pose 2.

cost_img = 0;
cost_hid = 0;
expname = 'dis';

% Fail early with a clear message instead of hitting an undefined-variable
% error inside the loop when expname is not one of the supported modes.
if ~any(strcmp(expname, {'dis', 'mul'}))
    error('Unsupported expname ''%s''; expected ''dis'' or ''mul''.', expname);
end

% Iterate over all pairs of ids.
num_test_pairs = size(test_vrf_pairs,1);
for p = 1:num_test_pairs
    id1 = test_vrf_pairs(p,1);
    id2 = test_vrf_pairs(p,2);
    data1 = load(testfiles{id1});
    data2 = load(testfiles{id2});

    fprintf('test analogy pair %d of %d, [%d,%d]\n', ...
            p, num_test_pairs, id1, id2);

    ept1 = test_vrf_pairs(p,3);
    apt1 = test_vrf_pairs(p,4);
    ept2 = test_vrf_pairs(p,5);
    apt2 = test_vrf_pairs(p,6);

    if strcmp(expname,'dis')
        img1 = imresize(single(data1.im(:,:,:,ept1,apt1))/255., [64 64]);
        img2 = imresize(single(data2.im(:,:,:,ept2,apt2))/255., [64 64]);
        img3 = imresize(single(data1.im(:,:,:,ept2,apt2))/255., [64 64]);
        img4 = img3;
        img_target = img3;
    elseif strcmp(expname,'mul')
        img1 = imresize(single(data2.im(:,:,:,ept1,apt1))/255., [64 64]);
        img2 = imresize(single(data2.im(:,:,:,ept2,apt2))/255., [64 64]);
        img3 = imresize(single(data1.im(:,:,:,ept1,apt1))/255., [64 64]);
        img4 = imresize(single(data1.im(:,:,:,ept2,apt2))/255., [64 64]);
        img_target = img4;
    end

    % Replicate each image along dim 4 to fill a whole batch.
    img1 = repmat(img1, [1,1,1,pars.batchsize]);
    img2 = repmat(img2, [1,1,1,pars.batchsize]);
    img3 = repmat(img3, [1,1,1,pars.batchsize]);
    img4 = repmat(img4, [1,1,1,pars.batchsize]);

    % Outputs 1..4 are assigned to hid2, hid3, hid1, hid4 respectively.
    % NOTE(review): this assumes the network returns the features of the
    % first input as its third output -- confirm against the prototxt.
    res = caffe('forward', { img1; img2; img3; img4 }, 0);
    hid2 = squeeze(res{1});
    hid3 = squeeze(res{2});
    hid1 = squeeze(res{3});
    hid4 = squeeze(res{4});

    if strcmp(expname,'dis')
        top = hid1;
        top(pars.pose_idx,:) = hid2(pars.pose_idx,:);
        trans = 0*top;
        sw = 0*top;
        hid_target = hid3;
    elseif strcmp(expname, 'mul')
        top = hid3;
        trans = hid2 - hid1;
        sw = ones(size(top),'single');
        hid_target = hid4;
    end

    % The mode-1 network takes the pose rows and identity rows separately.
    trans = trans(pars.pose_idx,:);
    top_pose = top(pars.pose_idx,:);
    top_id = top(pars.id_idx,:);

    %res = caffe('forward', { trans; top; sw; }, 1);
    res = caffe('forward', { trans; top_id; top_pose; sw; }, 1);
    pred = squeeze(res{3});

    % Only the first batch element is scored; the rest are replicas.
    err_hid = top(:,1) - hid_target(:,1);
    err_img = pred(:,:,:,1) - img_target(:,:,:,1);

    cost_hid = cost_hid + sqrt(err_hid(:)'*err_hid(:));
    cost_img = cost_img + sqrt(err_img(:)'*err_img(:));
end
% Average over the pairs actually iterated. This used to divide by
% size(testpairs,1), which only matches when the reloaded test_vrf_pairs
% has the same number of rows as testpairs.
cost_hid = cost_hid / num_test_pairs;
cost_img = cost_img / num_test_pairs;

fprintf(1,'Analogy cost_hid: %.4g, cost_img: %.4g\n', cost_hid, cost_img);

%% Single pair testing.
% Qualitative check on hand-picked test cars. For sample s the reference is
% car cars1(s) at pose poses1(s,:), the "out" image is car cars2(s) at pose
% poses2(s,:), and the target is car cars2(s) at the reference pose.
cars1 = [ 4, 5, 13, 7, 8 ];
cars2 = [ 3, 9, 15, 11, 16 ];
poses1 = [  7 1;
           10 2;
            5 3;
            1 1;
            4 4 ];
poses2 = [  2 3;
            7 1;
           18 4;
            9 3;
           18 2 ];
ns = length(cars1);

refs = zeros(64,64,3,ns);
outs = zeros(64,64,3,ns);
targets = zeros(64,64,3,ns);
preds = zeros(64,64,3,ns);
for s = 1:ns
    data1 = load(testfiles{cars1(s)});
    data2 = load(testfiles{cars2(s)});
    pose1 = poses1(s,:);
    pose2 = poses2(s,:);

    img_ref = single(data1.im(:,:,:,pose1(1),pose1(2))) / 255.;
    img_out = single(data2.im(:,:,:,pose2(1),pose2(2))) / 255.;
    img_target = single(data2.im(:,:,:,pose1(1),pose1(2))) / 255.;

    % Network inputs, resized to 64x64; the target image is passed twice.
    Xin = { imresize(img_ref, [64 64]); ...
            imresize(img_out, [64 64]); ...
            imresize(img_target, [64 64]); ...
            imresize(img_target, [64 64]) };
    result = caffe('forward', Xin, 0);
    hid_out = squeeze(result{1});
    hid_query = squeeze(result{2});
    hid_ref = squeeze(result{3});
    hid_target = squeeze(result{4});

    % Hard-coded row ranges used for the recombination below.
    % NOTE(review): these are not taken from pars.id_idx / pars.pose_idx,
    % which are used for the split a few lines further down -- confirm
    % that the two agree.
    iden_idx = 1:512;
    pose_idx = 513:640;

    % Start from hid_query, then overwrite the identity rows from hid_out
    % and the pose rows from hid_ref.
    top = hid_query;
    top(iden_idx,:) = hid_out(iden_idx,:);
    top(pose_idx,:) = hid_ref(pose_idx,:);

    % Zero transformation and zero switch; pose and identity rows are fed
    % to the mode-1 network as separate inputs.
    sw = 0*top;
    trans = 0*top(pars.pose_idx,:);
    top_id = top(pars.id_idx,:);
    top_pose = top(pars.pose_idx,:);
    res = caffe('forward', { trans; top_id; top_pose; sw }, 1);
    pred = res{3}(:,:,:,1);

    refs(:,:,:,s) = Xin{1}(:,:,:,1);
    outs(:,:,:,s) = Xin{2}(:,:,:,1);
    targets(:,:,:,s) = Xin{3}(:,:,:,1);
    preds(:,:,:,s) = pred(:,:,:,1);
end

% Stack the ns images of each kind vertically, then place the four kinds
% side by side: reference | out | target | prediction.
stack_rows = @(x) reshape(permute(x, [ 1 4 2 3 ]), [ 64*ns, 64, 3]);
vis = [ stack_rows(refs), stack_rows(outs), stack_rows(targets), stack_rows(preds) ];
imagesc(vis);
axis off;
savefig('car_disentangled','pdf');

%%
%% Single pair testing.
% Smaller three-sample variant of the qualitative check. For sample s the
% reference is car cars1(s) at pose poses1(s,:), the "out" image is car
% cars2(s) at pose poses2(s,:), and the target is car cars2(s) at the
% reference pose.
cars1 = [ 4, 8, 13 ];
cars2 = [ 3, 16, 15 ];
poses1 = [  7 1;
            4 4;
            5 3 ];
poses2 = [  2 3;
           18 2;
           18 4 ];
ns = length(cars1);

refs = zeros(64,64,3,ns);
outs = zeros(64,64,3,ns);
targets = zeros(64,64,3,ns);
preds = zeros(64,64,3,ns);
for s = 1:ns
    data1 = load(testfiles{cars1(s)});
    data2 = load(testfiles{cars2(s)});
    pose1 = poses1(s,:);
    pose2 = poses2(s,:);

    img_ref = single(data1.im(:,:,:,pose1(1),pose1(2))) / 255.;
    img_out = single(data2.im(:,:,:,pose2(1),pose2(2))) / 255.;
    img_target = single(data2.im(:,:,:,pose1(1),pose1(2))) / 255.;

    % Network inputs, resized to 64x64; the target image is passed twice.
    Xin = { imresize(img_ref, [64 64]); ...
            imresize(img_out, [64 64]); ...
            imresize(img_target, [64 64]); ...
            imresize(img_target, [64 64]) };
    result = caffe('forward', Xin, 0);
    hid_out = squeeze(result{1});
    hid_query = squeeze(result{2});
    hid_ref = squeeze(result{3});
    hid_target = squeeze(result{4});

    % Hard-coded row ranges used for the recombination below.
    iden_idx = 1:512;
    pose_idx = 513:640;

    % Start from hid_query, then overwrite the identity rows from hid_out
    % and the pose rows from hid_ref.
    top = hid_query;
    top(iden_idx,:) = hid_out(iden_idx,:);
    top(pose_idx,:) = hid_ref(pose_idx,:);
    trans = 0*top;
    sw = 0*top;

    % Disabled alternative that feeds identity and pose rows separately:
    %   trans = trans(pars.pose_idx,:);
    %   top_id = top(pars.id_idx,:);
    %   top_pose = top(pars.pose_idx,:);
    %   res = caffe('forward', { trans; top_id; top_pose; sw }, 1);
    % The active call passes the full-size trans, top and sw instead.
    res = caffe('forward', { trans; top; sw }, 1);
    pred = res{3}(:,:,:,1);

    refs(:,:,:,s) = Xin{1}(:,:,:,1);
    outs(:,:,:,s) = Xin{2}(:,:,:,1);
    targets(:,:,:,s) = Xin{3}(:,:,:,1);
    preds(:,:,:,s) = pred(:,:,:,1);
end

%%
% Stack the ns images of each kind vertically and place the four kinds side
% by side (reference | out | target | prediction), then keep only three
% fixed row bands of the montage before displaying and saving it.
stack_rows = @(x) reshape(permute(x, [ 1 4 2 3 ]), [ 64*ns, 64, 3]);
vis = [ stack_rows(refs), stack_rows(outs), stack_rows(targets), stack_rows(preds) ];
keep_rows = [ 15:45, 75:115, 145:(size(vis,1) - 10) ];
vis2 = vis(keep_rows,:,:);
imagesc(vis2);
axis off;
savefig('car_disentangled_small','jpeg');

