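"""Builds Caffe network definitions for a disentangling model.

Generates the identity ('id'), variation ('var') and reconstruction columns
and writes them out as multipie_*_deploy.prototxt files.  A helper that
assembles an Inception-style (GoogLeNet) submodule is also provided.
"""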

# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods

from caffe.proto import caffe_pb2
from caffe import Net

from google.protobuf import text_format

# pylint: disable=invalid-name
# pylint: disable=no-member
# Short aliases for enum types nested in the generated caffe_pb2 messages.
# The builder methods below look constants up by name, e.g.
# LayerType.Value('RELU') or PoolMethod.Value('MAX').
LayerType = caffe_pb2.LayerParameter.LayerType
PoolMethod = caffe_pb2.PoolingParameter.PoolMethod
# NOTE(review): DBType is not referenced by the builder code below -- confirm
# whether it is still needed.
DBType = caffe_pb2.DataParameter.DB
EltwiseOp = caffe_pb2.EltwiseParameter.EltwiseOp
# pylint: enable=invalid-name
# pylint: enable=no-member


class DisentangleBuilder(object):

  def __init__(self, training_batch_size, testing_batch_size, **kwargs):
    self.training_batch_size = training_batch_size
    self.testing_batch_size = testing_batch_size
    self.other_args = kwargs


  def _make_inception(self, network, x1x1, x3x3r, x3x3, x5x5r, x5x5, proj,
                      name_generator):
    """Make an Inception submodule with four parallel branches.

    A split layer feeds four branches whose outputs meet in one concat layer:
      1. 1x1 convolution (`x1x1` outputs),
      2. 1x1 convolution (`x3x3r` outputs) followed by a 3x3 convolution
         (`x3x3` outputs, pad 1),
      3. 1x1 convolution (`x5x5r` outputs) followed by a 5x5 convolution
         (`x5x5` outputs, pad 2),
      4. 3x3 max pooling followed by a 1x1 convolution (`proj` outputs,
         pad 1).
    Every convolution and the pooling layer is followed by a ReLU.

    Args:
      network: NetParameter message the layers are added to.
      x1x1, x3x3r, x3x3, x5x5r, x5x5, proj: output counts, see above.
      name_generator: iterator of unique ids used to name connecting blobs.

    Returns:
      List of the created layers, split first and concat last.  The bottom of
      the split and the top of the concat are left for the caller to connect.
    """
    conv = self._make_conv_layer
    relu = self._make_relu_layer

    # Layers are created strictly in this order so that they are added to
    # `network` exactly as before.
    split = self._make_split_layer(network)

    branch1 = conv(network, kernel_size=1, num_output=x1x1, bias_value=0.1)
    branch1_relu = relu(network)

    reduce3 = conv(network, kernel_size=1, num_output=x3x3r, bias_value=0.1)
    reduce3_relu = relu(network)
    branch3 = conv(network, kernel_size=3, num_output=x3x3, pad=1)
    branch3_relu = relu(network)

    reduce5 = conv(network, kernel_size=1, num_output=x5x5r, bias_value=0.1)
    reduce5_relu = relu(network)
    branch5 = conv(network, kernel_size=5, num_output=x5x5, pad=2)
    branch5_relu = relu(network)

    pool = self._make_maxpool_layer(network, kernel_size=3)
    pool_relu = relu(network)
    pool_proj = conv(network, kernel_size=1, num_output=proj, pad=1,
                     bias_value=0.1)
    pool_proj_relu = relu(network)

    concat = self._make_concat_layer(network)

    # Fan the split out to the head of every branch.
    wiring = [(split.name, (split.top, head.bottom))
              for head in (branch1, reduce3, reduce5, pool)]
    # Inside the multi-layer branches: head -> ReLU (in place) -> tail.
    wiring += [
      (reduce3.name,
          (reduce3.top, reduce3_relu.bottom, reduce3_relu.top, branch3.bottom)),
      (reduce5.name,
          (reduce5.top, reduce5_relu.bottom, reduce5_relu.top, branch5.bottom)),
      (pool.name,
          (pool.top, pool_relu.bottom, pool_relu.top, pool_proj.bottom)),
    ]
    # Tail of every branch -> ReLU (in place) -> concat.
    wiring += [
      (tail.name, (tail.top, tail_relu.bottom, tail_relu.top, concat.bottom))
      for tail, tail_relu in ((branch1, branch1_relu),
                              (branch3, branch3_relu),
                              (branch5, branch5_relu),
                              (pool_proj, pool_proj_relu))
    ]

    for link in wiring:
      self._tie(link, name_generator)

    return [
      split,
      branch1, branch1_relu,
      reduce3, reduce3_relu, branch3, branch3_relu,
      reduce5, reduce5_relu, branch5, branch5_relu,
      pool, pool_relu, pool_proj, pool_proj_relu,
      concat,
    ]

  def _make_sum_layer(self, network):
    """Make an element-wise sum layer.

    Adds a layer named 'sum' of type ELTWISE with operation SUM to `network`
    and returns it; its bottoms and top are left for the caller to connect.
    """
    eltwise = network.layers.add()
    eltwise.type = LayerType.Value('ELTWISE')
    eltwise.name = 'sum'
    eltwise.eltwise_param.operation = EltwiseOp.Value('SUM')
    return eltwise

  def _make_upsampling_layer(self, network, stride):
    """Make an upsampling layer.

    Adds a layer named 'upsample' of type UPSAMPLING to `network`; `stride`
    is stored as the `kernel_size` of its upsampling parameters.

    Returns:
      The new layer, not yet connected.
    """
    upsampling = network.layers.add()
    upsampling.type = LayerType.Value('UPSAMPLING')
    upsampling.name = 'upsample'
    upsampling.upsampling_param.kernel_size = stride
    return upsampling

  def _make_folding_layer(self, network, channels, height, width, prefix=''):
    """Make a folding layer.

    Args:
      network: NetParameter message the layer is added to.
      channels: stored as `channels_folded`.
      height: stored as `height_folded`.
      width: stored as `width_folded`.
      prefix: prepended to the layer name, which is '<prefix>_folding'
        ('_folding' with the default empty prefix).

    Returns:
      The new FOLDING layer, not yet connected.
    """
    folding = network.layers.add()
    folding.type = LayerType.Value('FOLDING')
    folding.name = '%s_folding' % prefix

    target_shape = folding.folding_param
    target_shape.width_folded = width
    target_shape.height_folded = height
    target_shape.channels_folded = channels
    return folding

  def _make_conv_layer(self, network, kernel_size, num_output, stride=1, pad=0,
                       bias_value=0.1, shared_name=None, wtype='xavier', std=0.01):
    """Make a convolution layer with a square kernel.

    Args:
      network: NetParameter message the layer is added to.
      kernel_size: kernel side length.
      num_output: number of output channels.
      stride: convolution stride.
      pad: padding.
      bias_value: value of the constant bias filler.
      shared_name: when truthy, the layer's two parameter blobs are named
        '<shared_name>_w' and '<shared_name>_b'.
      wtype: weight filler type.
      std: standard deviation of the weight filler; only applied when the
        filler type is 'gaussian' (its mean is then set to 0).

    Returns:
      The new layer, named 'conv_<k>x<k>_<stride>' and not yet connected.
    """
    conv = network.layers.add()
    conv.type = LayerType.Value('CONVOLUTION')
    conv.name = 'conv_%dx%d_%d' % (kernel_size, kernel_size, stride)

    conv_param = conv.convolution_param
    conv_param.num_output = num_output
    conv_param.kernel_size = kernel_size
    conv_param.stride = stride
    conv_param.pad = pad

    conv_param.weight_filler.type = wtype
    if conv_param.weight_filler.type == 'gaussian':
      conv_param.weight_filler.mean = 0
      conv_param.weight_filler.std = std

    conv_param.bias_filler.type = 'constant'
    conv_param.bias_filler.value = bias_value

    # One entry per parameter blob: weights first, then bias.
    conv.blobs_lr.extend([1, 2])
    conv.weight_decay.extend([1, 0])

    if shared_name:
      conv.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return conv

  def _make_maxpool_layer(self, network, kernel_size, stride=1):
    """Make a max pooling layer with a square window.

    Args:
      network: NetParameter message the layer is added to.
      kernel_size: pooling window side length.
      stride: pooling stride.

    Returns:
      The new layer, named 'maxpool_<k>x<k>_<stride>' and not yet connected.
    """
    pooling = network.layers.add()
    pooling.type = LayerType.Value('POOLING')
    pooling.name = 'maxpool_%dx%d_%d' % (kernel_size, kernel_size, stride)

    pooling.pooling_param.pool = PoolMethod.Value('MAX')
    pooling.pooling_param.kernel_size = kernel_size
    pooling.pooling_param.stride = stride
    return pooling

  def _make_avgpool_layer(self, network, kernel_size, stride=1):
    """Make an average pooling layer with a square window.

    Args:
      network: NetParameter message the layer is added to.
      kernel_size: pooling window side length.
      stride: pooling stride.

    Returns:
      The new layer, named 'avgpool_<k>x<k>_<stride>' and not yet connected.
    """
    pooling = network.layers.add()
    pooling.type = LayerType.Value('POOLING')
    pooling.name = 'avgpool_%dx%d_%d' % (kernel_size, kernel_size, stride)

    pooling.pooling_param.pool = PoolMethod.Value('AVE')
    pooling.pooling_param.kernel_size = kernel_size
    pooling.pooling_param.stride = stride
    return pooling

  def _make_lrn_layer(self, network, name='lrn'):
    """Make a local response normalization layer.

    The layer uses fixed settings: local_size 5, alpha 0.0001, beta 0.75.

    Args:
      network: NetParameter message the layer is added to.
      name: layer name.

    Returns:
      The new LRN layer, not yet connected.
    """
    lrn = network.layers.add()
    lrn.type = LayerType.Value('LRN')
    lrn.name = name

    settings = lrn.lrn_param
    settings.beta = 0.75
    settings.alpha = 0.0001
    settings.local_size = 5
    return lrn

  def _make_concat_layer(self, network):
    """Make a concatenation layer.

    Adds a layer named 'concat' of type CONCAT to `network` and returns it;
    no concat parameters are set and nothing is connected yet.
    """
    concat = network.layers.add()
    concat.type = LayerType.Value('CONCAT')
    concat.name = 'concat'
    return concat

  def _make_dropout_layer(self, network, dropout_ratio=0.5):
    """Make a dropout layer.

    Args:
      network: NetParameter message the layer is added to.
      dropout_ratio: stored as the layer's `dropout_ratio`.

    Returns:
      The new layer, named 'dropout' and not yet connected.
    """
    dropout = network.layers.add()
    dropout.type = LayerType.Value('DROPOUT')
    dropout.name = 'dropout'
    dropout.dropout_param.dropout_ratio = dropout_ratio
    return dropout

  def _make_inner_product_layer(self, network, num_output, weight_lr=1,
                                bias_lr=2, bias_value=0.1, prefix=''):
    """Make an inner product (fully connected) layer.

    Weights use the 'xavier' filler and the bias a constant filler.

    Args:
      network: NetParameter message the layer is added to.
      num_output: number of outputs.
      weight_lr: blobs_lr entry for the weights.
      bias_lr: blobs_lr entry for the bias.
      bias_value: value of the constant bias filler.  This argument used to
        be ignored in favour of a hard-coded 0.1; the default is unchanged,
        so callers that do not pass it get the same layer as before.
      prefix: prepended to the layer name, which is '<prefix>inner_product'.

    Returns:
      The new layer, not yet connected.
    """
    layer = network.layers.add()
    layer.name = '%sinner_product' % prefix

    layer.type = LayerType.Value('INNER_PRODUCT')
    params = layer.inner_product_param
    params.num_output = num_output
    weight_filler = params.weight_filler
    weight_filler.type = 'xavier'
    bias_filler = params.bias_filler
    bias_filler.type = 'constant'
    bias_filler.value = bias_value

    # One entry per parameter blob: weights first, then bias.
    layer.blobs_lr.append(weight_lr)
    layer.blobs_lr.append(bias_lr)

    layer.weight_decay.append(1)
    layer.weight_decay.append(0)

    return layer

  def _make_split_layer(self, network):
    """Make a split layer.

    Adds a layer named 'split' of type SPLIT to `network` and returns it;
    bottom and tops are left for the caller to connect.
    """
    split = network.layers.add()
    split.type = LayerType.Value('SPLIT')
    split.name = 'split'
    return split

  def _make_relu_layer(self, network):
    """Make a ReLU layer.

    Adds a layer named 'relu' of type RELU to `network` and returns it;
    bottom and top are left for the caller to connect.
    """
    relu = network.layers.add()
    relu.type = LayerType.Value('RELU')
    relu.name = 'relu'
    return relu

  def _make_sigmoid_layer(self, network):
    """Make a sigmoid layer.

    Adds a layer of type SIGMOID to `network` and returns it; bottom and top
    are left for the caller to connect.  The layer is named 'sigmoid' (it was
    previously named 'relu', a copy-paste slip from the ReLU factory).
    """

    layer = network.layers.add()
    layer.name = 'sigmoid'

    layer.type = LayerType.Value('SIGMOID')

    return layer

  def _tie(self, layers, name_generator):
    """Generate a named connection between layer endpoints."""

    name = '%s_%d' % (layers[0], name_generator.next())
    for layer in layers[1]:
      layer.append(name)

  def _connection_name_generator(self):
    """Generate a unique id."""

    index = 0
    while True:
      yield index
      index += 1

  def _make_column(self, network, name_generator, emb_prefix=''):
    """Make one encoder column: three conv stages and two inner products.

    Structure (sizes follow the comments below, for a 48 x 48 x 1 input):
      conv 5x5/64 -> maxpool 2/2 -> ReLU -> LRN
      conv 5x5/96 -> maxpool 2/2 -> ReLU -> LRN
      conv 5x5/128 -> maxpool 3/2 -> ReLU
      inner product 512 -> ReLU -> inner product 64

    Args:
      network: NetParameter message the layers are added to.
      name_generator: iterator of unique ids used to name connecting blobs.
      emb_prefix: accepted but currently not used by this method.

    Returns:
      List of the created layers in forward order.  The bottom of the first
      and the top of the last layer are left for the caller to connect.
    """

    def conv5(num_output):
      """Make a 5x5 convolution with pad 2 and bias 0.1."""
      return self._make_conv_layer(network, kernel_size=5,
                                   num_output=num_output, pad=2,
                                   bias_value=0.1)

    # Layers are created in forward order so that they are added to
    # `network` exactly as before.

    # 48 x 48 x 1 input
    conv1 = conv5(64)
    # 48 x 48 x 64 input
    maxpool1 = self._make_maxpool_layer(network, kernel_size=2, stride=2)
    # 24 x 24 x 64 input
    relu1 = self._make_relu_layer(network)
    lrn1 = self._make_lrn_layer(network)

    # 24 x 24 x 64 input (replace inception by a simple conv)
    conv2 = conv5(96)
    # 24 x 24 x 96 input
    maxpool2 = self._make_maxpool_layer(network, kernel_size=2, stride=2)
    # 12 x 12 x 96 input
    relu2 = self._make_relu_layer(network)
    lrn2 = self._make_lrn_layer(network)

    # 12 x 12 x 96 input (replace inception by a simple conv)
    conv3 = conv5(128)
    # 12 x 12 x 128 input
    maxpool3 = self._make_maxpool_layer(network, kernel_size=3, stride=2)
    # 6 x 6 x 128 input
    relu3 = self._make_relu_layer(network)

    # Projection to embedding layer.
    fc1 = self._make_inner_product_layer(network, num_output=512)
    relu4 = self._make_relu_layer(network)
    fc2 = self._make_inner_product_layer(network, num_output=64)

    # Each entry is tied to one blob name; a ReLU listed with both its
    # bottom and its top therefore runs in place.
    wiring = (
      (conv1.name, (conv1.top, maxpool1.bottom)),
      (maxpool1.name, (maxpool1.top, relu1.bottom, relu1.top, lrn1.bottom)),
      (lrn1.name, (lrn1.top, conv2.bottom)),
      (conv2.name, (conv2.top, maxpool2.bottom)),
      (maxpool2.name, (maxpool2.top, relu2.bottom, relu2.top, lrn2.bottom)),
      (lrn2.name, (lrn2.top, conv3.bottom)),
      (conv3.name, (conv3.top, maxpool3.bottom)),
      (maxpool3.name, (maxpool3.top, relu3.bottom, relu3.top, fc1.bottom)),
      (fc1.name, (fc1.top, relu4.bottom, relu4.top, fc2.bottom)),
    )
    for link in wiring:
      self._tie(link, name_generator)

    return [conv1, maxpool1, relu1, lrn1,
            conv2, maxpool2, relu2, lrn2,
            conv3, maxpool3, relu3,
            fc1, relu4, fc2]

  def _make_reconstruction_recurrent(self, network, name_generator, numstep):
    layers = []
    wtype = 'xavier'
    #wtype = 'gaussian'
    wf_std = 0.01

    # ID to hid.
    id_ip1 = self._make_inner_product_layer(network, num_output=256)
    id_ip1.bottom.append('id')
    layers.append(id_ip1)

    id_relu1 = self._make_relu_layer(network)
    layers.append(id_relu1)

    # Var to hid.
    var_ip1 = self._make_inner_product_layer(network, num_output=256)
    var_ip1.bottom.append('var')
    layers.append(var_ip1)

    var_relu1 = self._make_relu_layer(network)
    layers.append(var_relu1)

    # Concat.
    concat = self._make_concat_layer(network)
    layers.append(concat)

    # Another FC layer.
    ip2 = self._make_inner_product_layer(network, num_output=512)
    layers.append(ip2)

    relu2 = self._make_relu_layer(network)
    layers.append(relu2)

    # Another FC layer.
    ip3 = self._make_inner_product_layer(network, num_output=512)
    layers.append(ip3)
    
    relu3 = self._make_relu_layer(network)
    layers.append(relu3)

    # Another FC layer.
    ip4 = self._make_inner_product_layer(network, num_output=8*8*32)
    layers.append(ip4)

    relu4 = self._make_relu_layer(network)
    layers.append(relu4)

    # Reshape into tensor.
    folding = self._make_folding_layer(network, 32, 8, 8)
    layers.append(folding)

    conv1 = []
    conv1_split = []
    deconv1 = []
    unpool1_relu = []
    conv2 = []
    conv2_split = []
    deconv2 = []
    upsample2 = []
    unpool2_relu = []
    conv3 = []
    deconv3 = []
    upsample3 = []
    unpool3_relu = []

    # unpooling layer -> 16x16x32.
    upsample1 = self._make_upsampling_layer(network, stride=2)
    layers.append(upsample1)

    upsample1_split = self._make_split_layer(network)
    layers.append(upsample1_split)

    # First step has no bottom-up.
    unpool1 = []

    # Relu
    unpool1_relu.append(self._make_relu_layer(network))

    # top-down conv -> 14x14x32
    conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                       num_output=32, pad=1, bias_value=0.1,
                                       shared_name='conv1', wtype=wtype,
                                       std=wf_std))

    # split conv1.
    conv1_split.append(self._make_split_layer(network))

    # conv1[-1] is now 14x14x32.
    # unpooling layer -> 28x28x32.
    upsample2.append(self._make_upsampling_layer(network, stride=2))

    # Add bottom-up and top-down.
    unpool2 = []

    # Relu
    unpool2_relu.append(self._make_relu_layer(network))

    # top-down conv -> 26x26x32
    conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                       num_output=32, pad=1, bias_value=0.1,
                                       shared_name='conv2', wtype=wtype,
                                       std=wf_std))

    # split conv2.
    conv2_split.append(self._make_split_layer(network))

    # conv2[-1] is now 26x26x32.
    # unpooling layer -> 52x52x32.
    upsample3.append(self._make_upsampling_layer(network, stride=2))

    # Add bottom-up and top-down.
    unpool3 = []

    # Relu
    unpool3_relu.append(self._make_relu_layer(network))

    # top-down conv -> 48x48x1
    conv3.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                       num_output=1, pad=0, bias_value=0.0,
                                       shared_name='conv3', wtype=wtype,
                                       std=wf_std))
    for s in range(numstep-1):
      # deconv2 -> 16x16x32 from previous time step.
      deconv1.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                           num_output=32, pad=3, bias_value=0.1,
                                           shared_name='deconv1', wtype=wtype,
                                           std=wf_std))

      # Updated unpooling units for time s+1.
      unpool1.append(self._make_sum_layer(network))

      # Apply relu.
      unpool1_relu.append(self._make_relu_layer(network))

      # conv1 (-> 14x14x32).
      conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                         num_output=32, pad=1, bias_value=0.1,
                                         shared_name='conv1', wtype=wtype,
                                         std=wf_std))

      # split conv1.
      conv1_split.append(self._make_split_layer(network))

      # conv1[-1] is now 14x14x32.
      # unpooling layer -> 28x28x32.
      upsample2.append(self._make_upsampling_layer(network, stride=2))

      # deconv2 -> 28x28x32 from previous time step.
      deconv2.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                           num_output=32, pad=3, bias_value=0.1,
                                           shared_name='deconv2', wtype=wtype,
                                           std=wf_std))

      # Add bottom-up and top-down (28x28x32).
      unpool2.append(self._make_sum_layer(network))

      # Relu
      unpool2_relu.append(self._make_relu_layer(network))

      # top-down conv -> 26x26x32
      conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                         num_output=32, pad=1, bias_value=0.1,
                                         shared_name='conv2', wtype=wtype,
                                         std=wf_std))
      # split conv2.
      conv2_split.append(self._make_split_layer(network))

      # conv2[-1] is now 26x26x32.
      # unpooling layer -> 52x52x32.
      upsample3.append(self._make_upsampling_layer(network, stride=2))

      # deconv3 -> 52x52x32 from previous time step.
      deconv3.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                           num_output=32, pad=4, bias_value=0.1,
                                           shared_name='deconv3', wtype=wtype,
                                           std=wf_std))

      # Add bottom-up and top-down (52x52x32).
      unpool3.append(self._make_sum_layer(network))

      # Relu
      unpool3_relu.append(self._make_relu_layer(network))

      # top-down conv -> 48x48x1
      conv3.append(self._make_conv_layer(network, kernel_size=5, stride=1,
                                         num_output=1, pad=0, bias_value=0.0,
                                         shared_name='conv3', wtype=wtype,
                                         std=wf_std))

    conv3[-1].top.append('conv48')

    layers += deconv1
    layers += unpool1
    layers += unpool1_relu
    layers += conv1
    layers += conv1_split
    layers += upsample2
    layers += deconv2
    layers += unpool2
    layers += unpool2_relu
    layers += conv2
    layers += conv2_split
    layers += upsample3
    layers += deconv3
    layers += unpool3
    layers += unpool3_relu
    layers += conv3

    connections = [
      (id_ip1.name, (id_ip1.top, id_relu1.bottom, id_relu1.top, concat.bottom)),
      (var_ip1.name, (var_ip1.top, var_relu1.bottom, var_relu1.top, concat.bottom)),
      (concat.name, (concat.top, ip2.bottom)),
      (ip2.name, (ip2.top, relu2.bottom, relu2.top, ip3.bottom)),
      (ip3.name, (ip3.top, relu3.bottom, relu3.top, ip4.bottom)),
      (ip4.name, (ip4.top, relu4.bottom, relu4.top, folding.bottom)),
      (folding.name, (folding.top, upsample1.bottom)),
      (upsample1.name, (upsample1.top, upsample1_split.bottom))
    ]

    for s in range(numstep):
      ############ FIRST DECONV ############
      if s > 0:
        # upsample1 -> unpool1.
        connections.append((upsample1_split.name, (upsample1_split.top, unpool1[s-1].bottom)))
        # deconv1 -> unpool1.
        connections.append((deconv1[s-1].name, (deconv1[s-1].top, unpool1[s-1].bottom)))
        # unpool1 -> unpool1_relu.
        connections.append((unpool1[s-1].name, (unpool1[s-1].top, unpool1_relu[s].bottom)))
      else:
        # directly upsample1 -> unpool1_relu.
        connections.append((upsample1_split.name, (upsample1_split.top, unpool1_relu[s].bottom)))
      # unpool1_relu -> conv1.
      connections.append((unpool1_relu[s].name, (unpool1_relu[s].top, conv1[s].bottom)))
      # conv1 -> conv1_split.
      connections.append((conv1[s].name, (conv1[s].top, conv1_split[s].bottom)))
      if (s+1) < numstep:
        # conv1_split -> deconv1.
        connections.append((conv1_split[s].name, (conv1_split[s].top, deconv1[s].bottom)))
      # conv1_split -> upsample2.
      connections.append((conv1_split[s].name, (conv1_split[s].top, upsample2[s].bottom)))

      ############ SECOND DECONV ############
      if s > 0:
        # upsample2 -> unpool2.
        connections.append((upsample2[s].name, (upsample2[s].top, unpool2[s-1].bottom)))
        # deconv2 -> unpool2.
        connections.append((deconv2[s-1].name, (deconv2[s-1].top, unpool2[s-1].bottom)))
        # unpool2 -> unpool2_relu.
        connections.append((unpool2[s-1].name, (unpool2[s-1].top, unpool2_relu[s].bottom)))
      else:
        # directly upsample2 -> unpool2_relu.
        connections.append((upsample2[s].name, (upsample2[s].top, unpool2_relu[s].bottom)))
      # unpool2_relu -> conv2.
      connections.append((unpool2_relu[s].name, (unpool2_relu[s].top, conv2[s].bottom)))
      # conv2 -> conv2_split.
      connections.append((conv2[s].name, (conv2[s].top, conv2_split[s].bottom)))
      if (s+1) < numstep:
        # conv2_split -> deconv2.
        connections.append((conv2_split[s].name, (conv2_split[s].top, deconv2[s].bottom)))
      # conv2_split -> upsample3.
      connections.append((conv2_split[s].name, (conv2_split[s].top, upsample3[s].bottom)))

      ############ THIRD DECONV ############
      if s > 0:
        # upsample3 -> unpool3.
        connections.append((upsample3[s].name, (upsample3[s].top, unpool3[s-1].bottom)))
        # deconv3 -> unpool3.
        connections.append((deconv3[s-1].name, (deconv3[s-1].top, unpool3[s-1].bottom)))
        # unpool3 -> unpool3_relu.
        connections.append((unpool3[s-1].name, (unpool3[s-1].top, unpool3_relu[s].bottom)))
      else:
        # directly upsample3 -> unpool3_relu.
        connections.append((upsample3[s].name, (upsample3[s].top, unpool3_relu[s].bottom)))
      # unpool3_relu -> conv3.
      connections.append((unpool3_relu[s].name, (unpool3_relu[s].top, conv3[s].bottom)))
      if (s+1) < numstep:
        # conv3 -> deconv3.
        connections.append((conv3[s].name, (conv3[s].top, deconv3[s].bottom)))

    for connection in connections:
      self._tie(connection, name_generator)
    return layers


  def _make_reconstruction(self, network, name_generator):
    """Make the fully connected reconstruction column.

    The input blobs 'id' and 'var' each pass through an inner product (256)
    and a ReLU and are concatenated.  Inner products with 512, 512 and 2048
    outputs follow, each with a ReLU, then an inner product with 2304 outputs
    and a folding layer configured as 1 x 48 x 48.  The folding layer's top
    blob carries the folding layer's own name.

    Args:
      network: NetParameter message the layers are added to.
      name_generator: iterator of unique ids used to name connecting blobs.

    Returns:
      List of the created layers in forward order.
    """
    fc = self._make_inner_product_layer
    relu = self._make_relu_layer

    # Layers are created in forward order so that they are added to
    # `network` exactly as before.

    # ID to hid.
    id_ip1 = fc(network, num_output=256)
    id_ip1.bottom.append('id')
    id_relu1 = relu(network)

    # Var to hid.
    var_ip1 = fc(network, num_output=256)
    var_ip1.bottom.append('var')
    var_relu1 = relu(network)

    # Merge both codes.
    concat = self._make_concat_layer(network)

    # Stack of FC layers.
    ip2 = fc(network, num_output=512)
    relu2 = relu(network)
    ip3 = fc(network, num_output=512)
    relu3 = relu(network)
    ip4 = fc(network, num_output=2048)
    relu4 = relu(network)
    ip5 = fc(network, num_output=2304)

    # Reshape into tensor.
    folding = self._make_folding_layer(network, 1, 48, 48)
    folding.top.append(folding.name)

    wiring = (
      (id_ip1.name, (id_ip1.top, id_relu1.bottom, id_relu1.top, concat.bottom)),
      (var_ip1.name,
          (var_ip1.top, var_relu1.bottom, var_relu1.top, concat.bottom)),
      (concat.name, (concat.top, ip2.bottom)),
      (ip2.name, (ip2.top, relu2.bottom, relu2.top, ip3.bottom)),
      (ip3.name, (ip3.top, relu3.bottom, relu3.top, ip4.bottom)),
      (ip4.name, (ip4.top, relu4.bottom, relu4.top, ip5.bottom)),
      (ip5.name, (ip5.top, folding.bottom)),
    )
    for link in wiring:
      self._tie(link, name_generator)

    return [id_ip1, id_relu1, var_ip1, var_relu1, concat,
            ip2, relu2, ip3, relu3, ip4, relu4, ip5, folding]

  def _build_id_network(self):
    """Build the identity ('id') network definition.

    The net is named 'id_column' and declares the inputs 'ref', 'same' and
    'diff', each with dimensions 20 x 1 x 48 x 48, with force_backward set.
    One column from `_make_column` is created per input; its output blob is
    '<input>_id'.  Convolution and inner product layers receive the param
    names 'id_shared<pos>_w' / 'id_shared<pos>_b', where <pos> is the layer's
    position inside its column, so the three columns use the same names.

    Returns:
      The populated NetParameter message.
    """
    columns = ['ref', 'same', 'diff']

    network = caffe_pb2.NetParameter()
    network.name = 'id_column'
    for blob in columns:
      network.input.append(blob)
      network.input_dim.extend([20, 1, 48, 48])
    network.force_backward = True

    name_generator = self._connection_name_generator()

    # Conv and FC layers will share params among columns.
    shareable = (LayerType.Value('INNER_PRODUCT'),
                 LayerType.Value('CONVOLUTION'))

    all_layers = []
    for column in columns:
      id_layers = self._make_column(network, name_generator,
                                    emb_prefix='%s_%s_' % (column, 'id'))
      id_layers[0].bottom.append(column)
      id_layers[-1].top.append('%s_%s' % (column, 'id'))
      for pos, layer in enumerate(id_layers):
        if layer.type in shareable:
          layer.param.append('id_shared%d_w' % pos)
          layer.param.append('id_shared%d_b' % pos)
      all_layers.extend(id_layers)

    # Make layer names unique by appending the overall position.
    for pos, layer in enumerate(all_layers):
      layer.name = layer.name + '_%d' % pos
    return network

  def _build_var_network(self):
    """Build the variation ('var') network definition.

    The net is named 'var_column' and declares the inputs 'ref', 'same' and
    'diff', each with dimensions 20 x 1 x 48 x 48, with force_backward set.
    One column from `_make_column` is created per input; its output blob is
    '<input>_var'.  Convolution and inner product layers receive the param
    names 'var_shared<pos>_w' / 'var_shared<pos>_b', where <pos> is the
    layer's position inside its column, so the three columns use the same
    names.

    Returns:
      The populated NetParameter message.
    """
    columns = ['ref', 'same', 'diff']

    network = caffe_pb2.NetParameter()
    network.name = 'var_column'
    for blob in columns:
      network.input.append(blob)
      network.input_dim.extend([20, 1, 48, 48])
    network.force_backward = True

    name_generator = self._connection_name_generator()

    # Conv and FC layers will share params among columns.
    shareable = (LayerType.Value('INNER_PRODUCT'),
                 LayerType.Value('CONVOLUTION'))

    all_layers = []
    for column in columns:
      var_layers = self._make_column(network, name_generator,
                                     emb_prefix='%s_%s_' % (column, 'var'))
      var_layers[0].bottom.append(column)
      var_layers[-1].top.append('%s_%s' % (column, 'var'))
      for pos, layer in enumerate(var_layers):
        if layer.type in shareable:
          layer.param.append('var_shared%d_w' % pos)
          layer.param.append('var_shared%d_b' % pos)
      all_layers.extend(var_layers)

    # Make layer names unique by appending the overall position.
    for pos, layer in enumerate(all_layers):
      layer.name = layer.name + '_%d' % pos
    return network

  def _build_recon_recurrent_network(self, numstep=2):
    """Build the recurrent reconstruction network definition.

    The net is named 'recon_column' and declares the inputs 'id' and 'var',
    each with dimensions 20 x 64 x 1 x 1, with force_backward set.  Its
    layers come from `_make_reconstruction_recurrent`.

    Args:
      numstep: number of unrolled steps, passed on unchanged.

    Returns:
      The populated NetParameter message.
    """
    network = caffe_pb2.NetParameter()
    network.name = 'recon_column'
    for blob in ('id', 'var'):
      network.input.append(blob)
      network.input_dim.extend([20, 64, 1, 1])
    network.force_backward = True

    name_generator = self._connection_name_generator()
    recon_layers = self._make_reconstruction_recurrent(network, name_generator,
                                                       numstep)

    # Make layer names unique by appending the position in the returned list.
    for pos, layer in enumerate(list(recon_layers)):
      layer.name = layer.name + '_%d' % pos
    return network

  def _build_recon_network(self):
    """Build the fully connected reconstruction network definition.

    The net is named 'recon_column' and declares the inputs 'id' and 'var',
    each with dimensions 20 x 64 x 1 x 1, with force_backward set.  Its
    layers come from `_make_reconstruction`; convolution and inner product
    layers receive the param names 'recon_shared<pos>_w' /
    'recon_shared<pos>_b', where <pos> is the layer's position in that list.

    Returns:
      The populated NetParameter message.
    """
    network = caffe_pb2.NetParameter()
    network.name = 'recon_column'
    for blob in ('id', 'var'):
      network.input.append(blob)
      network.input_dim.extend([20, 64, 1, 1])
    network.force_backward = True

    name_generator = self._connection_name_generator()

    # Layer types that carry parameters and therefore get param names.
    shareable = (LayerType.Value('INNER_PRODUCT'),
                 LayerType.Value('CONVOLUTION'))

    recon_layers = self._make_reconstruction(network, name_generator)
    for pos, layer in enumerate(recon_layers):
      if layer.type in shareable:
        layer.param.append('recon_shared%d_w' % pos)
        layer.param.append('recon_shared%d_b' % pos)

    # Make layer names unique by appending the position in the list.
    for pos, layer in enumerate(list(recon_layers)):
      layer.name = layer.name + '_%d' % pos
    return network

  def _build_network(self):
    """Build the Disentangling network with all columns in a single net.

    The net is named 'disentangle' and declares the inputs 'ref', 'same' and
    'diff', each with dimensions 20 x 1 x 48 x 48.  For every input a split
    layer feeds an identity column and a variation column (both from
    `_make_column`), whose outputs '<input>_id' and '<input>_var' feed a
    reconstruction column (from `_make_reconstruction`) that writes
    '<input>_recon'.  Convolution and inner product layers receive per-role
    param names ('id_shared<pos>_*', 'var_shared<pos>_*',
    'recon_shared<pos>_*'), so the three inputs use the same names.

    This method previously called `_make_column` and `_make_reconstruction`
    with argument lists that do not match their signatures and therefore
    always raised TypeError; the calls now follow the actual signatures.
    NOTE(review): the wiring below is reconstructed from those signatures
    and the blob names used by the stand-alone nets -- confirm it matches
    the intended topology.

    Returns:
      The populated NetParameter message.
    """

    network = caffe_pb2.NetParameter()
    network.name = 'disentangle'
    for blob in ('ref', 'same', 'diff'):
      network.input.append(blob)
      network.input_dim.extend([20, 1, 48, 48])

    layers = []

    name_generator = self._connection_name_generator()

    # Conv and FC layers will share params among columns.
    param_types = [ LayerType.Value('INNER_PRODUCT'),
                    LayerType.Value('CONVOLUTION') ]

    def share_params(column_layers, role):
      """Give every conv/FC layer the per-position param names of `role`."""
      for pos, layer in enumerate(column_layers):
        if layer.type not in param_types:
          continue
        layer.param.append('%s_shared%d_w' % (role, pos))
        layer.param.append('%s_shared%d_b' % (role, pos))

    # Generate the columns.
    for column in [ 'ref', 'same', 'diff' ]:
      split = self._make_split_layer(network)
      split.bottom.append(column)
      layers.append(split)

      # ID column.
      id_layers = self._make_column(network, name_generator,
                                    emb_prefix='%s_%s_' % (column, 'id'))
      id_layers[-1].top.append('%s_%s' % (column, 'id'))
      share_params(id_layers, 'id')
      layers += id_layers

      # Var column.
      var_layers = self._make_column(network, name_generator,
                                     emb_prefix='%s_%s_' % (column, 'var'))
      var_layers[-1].top.append('%s_%s' % (column, 'var'))
      share_params(var_layers, 'var')
      layers += var_layers

      # Feed the input into both columns; each tie creates its own blob, so
      # the split layer ends up with one top per column.
      self._tie((split.name, (split.top, id_layers[0].bottom)),
                name_generator)
      self._tie((split.name, (split.top, var_layers[0].bottom)),
                name_generator)

      # Recon column.  _make_reconstruction wires its two entry layers
      # (positions 0 and 2 of the returned list) to the fixed blobs 'id' and
      # 'var' and names its output after the folding layer; re-point all
      # three so every input gets its own, uniquely named reconstruction.
      recon_layers = self._make_reconstruction(network, name_generator)
      recon_layers[0].bottom[0] = '%s_%s' % (column, 'id')
      recon_layers[2].bottom[0] = '%s_%s' % (column, 'var')
      recon_layers[-1].top[0] = '%s_recon' % column
      share_params(recon_layers, 'recon')
      layers += recon_layers

    # Fix up the names based on the connections that were generated.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def build_id(self):
    """Build the identity network, write it to disk and load it.

    Prints the network definition, writes it in protobuf text format to
    'multipie_id_deploy.prototxt' (relative to the current working
    directory) and constructs a caffe Net from that file.

    Returns:
      The caffe Net created from the written file.
    """

    network = self._build_id_network()
    # Function-call form: valid on Python 3 and, with a single argument,
    # prints exactly what the Python 2 print statement printed.
    print(network)
    network_filename = 'multipie_id_deploy.prototxt'
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)

  def build_var(self):
    """Build the variation network, write it to disk and load it.

    Prints the network definition, writes it in protobuf text format to
    'multipie_var_deploy.prototxt' (relative to the current working
    directory) and constructs a caffe Net from that file.

    Returns:
      The caffe Net created from the written file.
    """

    network = self._build_var_network()
    # Function-call form: valid on Python 3 and, with a single argument,
    # prints exactly what the Python 2 print statement printed.
    print(network)
    network_filename = 'multipie_var_deploy.prototxt'
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)

  def build_recon_recurrent(self):
    """Build the recurrent reconstruction network, write it and load it.

    The network is unrolled for two steps and written in protobuf text
    format to 'multipie_recon_deploy.prototxt' (relative to the current
    working directory).  `build_recon` writes to the same file name, so
    whichever of the two runs last determines the file's content.

    Returns:
      The caffe Net created from the written file.
    """
    prototxt_path = 'multipie_recon_deploy.prototxt'
    definition = self._build_recon_recurrent_network(numstep=2)
    with open(prototxt_path, 'w') as prototxt:
      prototxt.write(text_format.MessageToString(definition))
    return Net(prototxt_path)

  def build_recon(self):
    """Build the fully connected reconstruction network, write and load it.

    Prints the network definition, writes it in protobuf text format to
    'multipie_recon_deploy.prototxt' (relative to the current working
    directory) and constructs a caffe Net from that file.
    `build_recon_recurrent` writes to the same file name, so whichever of
    the two runs last determines the file's content.

    Returns:
      The caffe Net created from the written file.
    """

    network = self._build_recon_network()
    # Function-call form: valid on Python 3 and, with a single argument,
    # prints exactly what the Python 2 print statement printed.
    print(network)
    network_filename = 'multipie_recon_deploy.prototxt'
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)

  def build(self):
    """Build the complete disentangling network, write it and load it.

    Prints the network definition, writes it in protobuf text format to
    'multipie_deploy.prototxt' (relative to the current working directory)
    and constructs a caffe Net from that file.

    Returns:
      The caffe Net created from the written file.
    """

    network = self._build_network()
    # Function-call form: valid on Python 3 and, with a single argument,
    # prints exactly what the Python 2 print statement printed.
    print(network)
    network_filename = 'multipie_deploy.prototxt'
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)

if __name__ == '__main__':
  # Script entry point: generate the identity, variation and recurrent
  # reconstruction deploy files in the current working directory.
  __disentangle_builder__ = DisentangleBuilder(training_batch_size=20,
                                               testing_batch_size=20)

  # Not run here: build() (all columns in one net) and build_recon() (the
  # non-recurrent reconstruction, which writes the same file as
  # build_recon_recurrent()).
  for __build_step__ in (__disentangle_builder__.build_id,
                         __disentangle_builder__.build_var,
                         __disentangle_builder__.build_recon_recurrent):
    __build_step__()
