# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods

from caffe.proto import caffe_pb2
from caffe import Net

from google.protobuf import text_format

# pylint: disable=invalid-name
# pylint: disable=no-member
# Short aliases for the protobuf enum types used when filling in layers.
LayerType = caffe_pb2.LayerParameter.LayerType
EltwiseOp = caffe_pb2.EltwiseParameter.EltwiseOp
PoolMethod = caffe_pb2.PoolingParameter.PoolMethod
DBType = caffe_pb2.DataParameter.DB
# pylint: enable=invalid-name
# pylint: enable=no-member

# Module-level batch size. The network builders below accept their own
# `batchsize` argument, which shadows this name inside those methods.
batchsize = 1

class NetworkBuilder(object):

  def __init__(self, training_batch_size=20, testing_batch_size=20, **kwargs):
    """Record batch sizes and keep any extra keyword options.

    Args:
      training_batch_size: stored as `self.training_batch_size`.
      testing_batch_size: stored as `self.testing_batch_size`.
      **kwargs: stored unchanged as the dict `self.other_args`.
    """
    self.other_args = kwargs
    self.testing_batch_size = testing_batch_size
    self.training_batch_size = training_batch_size


  def _make_inception(self, network, x1x1, x3x3r, x3x3, x5x5r, x5x5, proj,
                      name_generator):
    """Make Inception submodule.

    Appends a split layer, four parallel branches and a concat layer to
    `network.layers`, then wires them together with `_tie`:

      branch 1: 1x1 conv (x1x1 outputs) -> ReLU
      branch 2: 1x1 conv (x3x3r) -> ReLU -> 3x3 conv (x3x3, pad 1) -> ReLU
      branch 3: 1x1 conv (x5x5r) -> ReLU -> 5x5 conv (x5x5, pad 2) -> ReLU
      branch 4: 3x3 max pool (stride 1) -> ReLU -> 1x1 conv (proj, pad 1)
                -> ReLU

    The split layer's bottom and the concat layer's top are not set here;
    presumably the caller connects them (TODO confirm).

    Args:
      network: NetParameter whose `layers` field receives the new layers.
      x1x1, x3x3r, x3x3, x5x5r, x5x5, proj: num_output of the respective
        convolutions listed above.
      name_generator: iterator of unique ints consumed by `_tie` to name the
        connecting blobs.

    Returns:
      The list of layers created, in the order they were added.
    """

    layers = []

    # Fan-out point: each of its tops (added by the ties below) feeds one
    # branch.
    split = self._make_split_layer(network)
    layers.append(split)

    # Branch 1: plain 1x1 convolution.
    context1 = self._make_conv_layer(network, kernel_size=1, num_output=x1x1,
                                     bias_value=0)
    layers.append(context1)

    relu1 = self._make_relu_layer(network)
    layers.append(relu1)

    # Branch 2: 1x1 reduction, then a 3x3 convolution padded to keep H/W.
    context2a = self._make_conv_layer(network, kernel_size=1, num_output=x3x3r,
                                      bias_value=0)
    layers.append(context2a)

    relu2a = self._make_relu_layer(network)
    layers.append(relu2a)

    # Keeps the helper's default bias_value (0.1), unlike the 1x1 convs.
    context2b = self._make_conv_layer(network, kernel_size=3, num_output=x3x3,
                                      pad=1)
    layers.append(context2b)

    relu2b = self._make_relu_layer(network)
    layers.append(relu2b)

    # Branch 3: 1x1 reduction, then a 5x5 convolution padded to keep H/W.
    context3a = self._make_conv_layer(network, kernel_size=1, num_output=x5x5r,
                                      bias_value=0)
    layers.append(context3a)

    relu3a = self._make_relu_layer(network)
    layers.append(relu3a)

    context3b = self._make_conv_layer(network, kernel_size=5, num_output=x5x5,
                                      pad=2)
    layers.append(context3b)

    relu3b = self._make_relu_layer(network)
    layers.append(relu3b)

    # Branch 4: the max pool has stride 1 and no padding, so it shrinks H and
    # W by 2. The pad=1 on the following 1x1 conv grows them back by 2,
    # giving this branch the same spatial size as the others for the concat.
    context4a = self._make_maxpool_layer(network, kernel_size=3)
    layers.append(context4a)

    relu4a = self._make_relu_layer(network)
    layers.append(relu4a)

    context4b = self._make_conv_layer(network, kernel_size=1, num_output=proj,
                                      pad=1, bias_value=0)
    layers.append(context4b)

    relu4b = self._make_relu_layer(network)
    layers.append(relu4b)

    # Receives one bottom per branch, in the order of the ties below.
    concat = self._make_concat_layer(network)
    layers.append(concat)

    # Each entry is (blob-name prefix, endpoint fields). `_tie` appends the
    # same generated blob name to every endpoint, so a ReLU's bottom and top
    # share one blob (in-place ReLU). The order here fixes both the
    # generated names and the concat's input order (branch 1, 2, 3, 4).
    connections = [
      (split.name, (split.top, context1.bottom)),
      (split.name, (split.top, context2a.bottom)),
      (split.name, (split.top, context3a.bottom)),
      (split.name, (split.top, context4a.bottom)),
      (context2a.name,
          (context2a.top, relu2a.bottom, relu2a.top, context2b.bottom)),
      (context3a.name,
          (context3a.top, relu3a.bottom, relu3a.top, context3b.bottom)),
      (context4a.name,
          (context4a.top, relu4a.bottom, relu4a.top, context4b.bottom)),
      (context1.name, (context1.top, relu1.bottom, relu1.top, concat.bottom)),
      (context2b.name,
          (context2b.top, relu2b.bottom, relu2b.top, concat.bottom)),
      (context3b.name,
          (context3b.top, relu3b.bottom, relu3b.top, concat.bottom)),
      (context4b.name,
          (context4b.top, relu4b.bottom, relu4b.top, concat.bottom)),
    ]

    for connection in connections:
      self._tie(connection, name_generator)

    return layers

  def _make_eltwise_layer(self, network, name, operation, coeff=None):
    """Append an ELTWISE layer applying the EltwiseOp named `operation`.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      name: layer name.
      operation: name of an EltwiseOp enum value, e.g. 'PROD' or 'SUM'.
      coeff: optional per-bottom coefficients. Nothing is written when it is
        None or empty.

    Returns:
      The newly added layer.
    """
    layer = network.layers.add()
    layer.name = name
    layer.type = LayerType.Value('ELTWISE')
    params = layer.eltwise_param
    params.operation = EltwiseOp.Value(operation)
    if coeff:
      params.coeff.extend(coeff)
    return layer

  def _make_prod_layer(self, network, coeff=None):
    """Make an element-wise product layer named 'prod'.

    NOTE(review): upstream Caffe's EltwiseLayer only accepts `coeff` for SUM;
    confirm this Caffe build allows coefficients with PROD before passing any.
    """
    return self._make_eltwise_layer(network, 'prod', 'PROD', coeff)

  def _make_sum_layer(self, network, coeff=None):
    """Make an element-wise (optionally weighted) sum layer named 'sum'."""
    return self._make_eltwise_layer(network, 'sum', 'SUM', coeff)

  def _make_upsampling_layer(self, network, stride):
    """Make an UPSAMPLING layer named 'upsample'.

    `stride` is written to `upsampling_param.kernel_size`.
    """
    upsample = network.layers.add()
    upsample.name = 'upsample'
    upsample.type = LayerType.Value('UPSAMPLING')
    upsample.upsampling_param.kernel_size = stride
    return upsample

  def _make_folding_layer(self, network, channels, height, width, prefix=''):
    """Make a FOLDING layer named '<prefix>folding'.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      channels, height, width: written to folding_param's channels_folded,
        height_folded and width_folded.
      prefix: text prepended to the layer name.

    Returns:
      The newly added layer.
    """
    fold = network.layers.add()
    fold.name = '%sfolding' % prefix
    fold.type = LayerType.Value('FOLDING')
    fold_params = fold.folding_param
    fold_params.channels_folded = channels
    fold_params.height_folded = height
    fold_params.width_folded = width
    return fold

  def _make_conv_layer(self, network, kernel_size, num_output, stride=1, pad=0,
                       bias_value=0.1, shared_name=None, wtype='xavier', std=0.01,
                       weight_lr=1, bias_lr=2):
    """Make convolution layer.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      kernel_size: square kernel size.
      num_output: number of output channels.
      stride: convolution stride.
      pad: zero padding on each spatial border.
      bias_value: value of the constant bias filler.
      shared_name: if given, the parameter blobs are named '<shared_name>_w'
        and '<shared_name>_b', so layers created with the same shared_name
        refer to the same named parameters.
      wtype: weight filler type; 'gaussian' additionally sets mean 0 and
        standard deviation `std`.
      std: standard deviation for the 'gaussian' filler (unused otherwise).
      weight_lr: learning-rate multiplier for the weights (blobs_lr[0]).
      bias_lr: learning-rate multiplier for the bias (blobs_lr[1]).

    Returns:
      The newly added layer, named 'conv_<k>x<k>_<stride>'.
    """
    layer = network.layers.add()
    layer.name = 'conv_%dx%d_%d' % (kernel_size, kernel_size, stride)
    layer.type = LayerType.Value('CONVOLUTION')

    params = layer.convolution_param
    params.num_output = num_output
    params.kernel_size = kernel_size
    params.stride = stride
    params.pad = pad

    weight_filler = params.weight_filler
    weight_filler.type = wtype
    if wtype == 'gaussian':
      weight_filler.mean = 0
      weight_filler.std = std

    bias_filler = params.bias_filler
    bias_filler.type = 'constant'
    bias_filler.value = bias_value

    layer.blobs_lr.extend([weight_lr, bias_lr])
    # Weight decay applies to the weights only, not to the bias.
    layer.weight_decay.extend([1, 0])

    if shared_name:
      layer.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return layer

  def _make_pool_layer(self, network, method, prefix, kernel_size, stride, pad):
    """Append a POOLING layer using the PoolMethod named `method`.

    The layer is named '<prefix>_<k>x<k>_<stride>'. `pad` is written only
    when non-zero, so pad=0 leaves pooling_param.pad unset.
    """
    layer = network.layers.add()
    layer.name = '%s_%dx%d_%d' % (prefix, kernel_size, kernel_size, stride)
    layer.type = LayerType.Value('POOLING')
    params = layer.pooling_param
    params.pool = PoolMethod.Value(method)
    params.kernel_size = kernel_size
    params.stride = stride
    if pad:
      params.pad = pad
    return layer

  def _make_maxpool_layer(self, network, kernel_size, stride=1, pad=0):
    """Make max pooling layer.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      kernel_size: square pooling window size.
      stride: pooling stride.
      pad: optional zero padding; the default 0 leaves the field unset.

    Returns:
      The newly added layer, named 'maxpool_<k>x<k>_<stride>'.
    """
    return self._make_pool_layer(network, 'MAX', 'maxpool', kernel_size,
                                 stride, pad)

  def _make_avgpool_layer(self, network, kernel_size, stride=1, pad=0):
    """Make average pooling layer.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      kernel_size: square pooling window size.
      stride: pooling stride.
      pad: optional zero padding; the default 0 leaves the field unset.

    Returns:
      The newly added layer, named 'avgpool_<k>x<k>_<stride>'.
    """
    return self._make_pool_layer(network, 'AVE', 'avgpool', kernel_size,
                                 stride, pad)

  def _make_lrn_layer(self, network, name='lrn'):
    """Make a local response normalization (LRN) layer called `name`.

    Uses local_size=5, alpha=0.0001 and beta=0.75.
    """
    lrn = network.layers.add()
    lrn.name = name
    lrn.type = LayerType.Value('LRN')
    lrn_params = lrn.lrn_param
    lrn_params.local_size, lrn_params.alpha, lrn_params.beta = 5, 0.0001, 0.75
    return lrn

  def _make_concat_layer(self, network):
    """Make a depth concatenation layer named 'concat'.

    No concat_param fields are set here.
    """
    concat = network.layers.add()
    concat.name = 'concat'
    concat.type = LayerType.Value('CONCAT')
    return concat

  def _make_dropout_layer(self, network, dropout_ratio=0.5):
    """Make a DROPOUT layer named 'dropout' with the given dropout_ratio."""
    dropout = network.layers.add()
    dropout.name = 'dropout'
    dropout.type = LayerType.Value('DROPOUT')
    dropout.dropout_param.dropout_ratio = dropout_ratio
    return dropout

  def _make_inner_product_layer(self, network, num_output, weight_lr=1,
                                bias_lr=2, bias_value=0.1, prefix='',
                                shared_name=None,
                                wtype='xavier', std=0.01):
    """Make inner product (fully connected) layer.

    Args:
      network: NetParameter whose `layers` field receives the layer.
      num_output: number of outputs.
      weight_lr: learning-rate multiplier for the weights (blobs_lr[0]).
      bias_lr: learning-rate multiplier for the bias (blobs_lr[1]).
      bias_value: value of the constant bias filler.
      prefix: text prepended to the layer name.
      shared_name: if given, the parameter blobs are named '<shared_name>_w'
        and '<shared_name>_b'.
      wtype: weight filler type; 'gaussian' additionally sets mean 0 and
        standard deviation `std`.
      std: standard deviation for the 'gaussian' filler (unused otherwise).

    Returns:
      The newly added layer, named '<prefix>inner_product'.
    """
    ip = network.layers.add()
    ip.name = '%sinner_product' % prefix
    ip.type = LayerType.Value('INNER_PRODUCT')

    ip_params = ip.inner_product_param
    ip_params.num_output = num_output

    filler = ip_params.weight_filler
    filler.type = wtype
    if wtype == 'gaussian':
      filler.mean, filler.std = 0, std

    ip_params.bias_filler.type = 'constant'
    ip_params.bias_filler.value = bias_value

    # Per-blob (weights, bias) learning-rate multiplier and weight decay.
    for lr_mult, decay in ((weight_lr, 1), (bias_lr, 0)):
      ip.blobs_lr.append(lr_mult)
      ip.weight_decay.append(decay)

    if shared_name:
      ip.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return ip

  def _make_split_layer(self, network):
    """Make a SPLIT layer named 'split'; no bottoms or tops are set here."""
    split = network.layers.add()
    split.name = 'split'
    split.type = LayerType.Value('SPLIT')
    return split

  def _make_relu_layer(self, network):
    """Make a RELU layer named 'relu'; no bottoms or tops are set here."""
    relu = network.layers.add()
    relu.name = 'relu'
    relu.type = LayerType.Value('RELU')
    return relu

  def _tie(self, layers, name_generator):
    """Generate a named connection between layer endpoints.

    Args:
      layers: pair `(prefix, endpoints)`. `prefix` (typically the producing
        layer's name) is embedded in the blob name; `endpoints` is a sequence
        of list-like fields (e.g. a layer's `top` / `bottom`).
      name_generator: iterator yielding unique ints.

    Every endpoint receives the same blob name 'ep_<prefix>_<id>', which is
    what links those layers together.
    """
    # The builtin next() works on Python 2.6+ and 3; the generator method
    # `.next()` exists only on Python 2.
    name = 'ep_%s_%d' % (layers[0], next(name_generator))
    for endpoint in layers[1]:
      endpoint.append(name)

  def _connection_name_generator(self):
    """Yield 0, 1, 2, ... indefinitely, for use as unique connection ids."""
    counter = 0
    while 1:
      current = counter
      counter += 1
      yield current

  def _build_small_analogy_network(self, wtype='xavier', std=0.01, batchsize=20):
    """Build the small fully connected analogy network.

    Inputs 'ref', 'out' and 'query' (each batchsize x 21*21*3 x 1 x 1) are
    encoded by a two-layer MLP whose weights are shared across inputs
    (ip1: 512 outputs + ReLU, ip2: 256 outputs). The codes are combined as
    -ref + out + query, then decoded by an inner product (512) + ReLU and a
    final inner product (21*21*3) into the 'prediction' blob.

    `wtype` and `std` are accepted but not used; every layer uses the
    helpers' default fillers.

    Returns:
      The populated NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'analogy_small'

    inputs = ['ref', 'out', 'query']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 21*21*3, 1, 1])

    name_generator = self._connection_name_generator()

    # Per-input encoder; ip1/ip2 parameters are shared via shared_name.
    fc1 = []
    relu1 = []
    fc2 = []
    for inp in inputs:
      fc1.append(self._make_inner_product_layer(network, num_output=512, shared_name='ip1'))
      fc1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))
      fc2.append(self._make_inner_product_layer(network, num_output=256, shared_name='ip2'))

    # Analogy in code space: -ref + out + query (bottoms are tied in input
    # order below, matching the coefficients).
    fc4_sum = self._make_sum_layer(network, coeff=[-1, 1, 1])

    fc5 = self._make_inner_product_layer(network, num_output=512)
    relu5 = self._make_relu_layer(network)

    fc6 = self._make_inner_product_layer(network, num_output=21*21*3)
    fc6.top.append('prediction')

    layers = fc1 + relu1 + fc2 + [fc4_sum, fc5, relu5, fc6]

    connections = []
    for fc1_i, relu1_i, fc2_i in zip(fc1, relu1, fc2):
      connections.append((fc1_i.name, (fc1_i.top, relu1_i.bottom, relu1_i.top, fc2_i.bottom)))
    for fc2_i in fc2:
      connections.append((fc2_i.name, (fc2_i.top, fc4_sum.bottom)))
    connections.append((fc4_sum.name, (fc4_sum.top, fc5.bottom)))
    connections.append((fc5.name, (fc5.top, relu5.bottom, relu5.top, fc6.bottom)))

    # Make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Make layer names unique by appending each layer's position in `layers`.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_shape_dec_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the 21x21 shape decoder.

    The single input 'hid' (batchsize x 512 x 1 x 1) feeds two independent
    branches, each: inner product (1024) -> ReLU -> inner product -> folding.
    One branch emits 'rgb' (3 x 21 x 21), the other 'mask' (1 x 21 x 21).

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'shapes_dec'
    net.input.append('hid')
    net.input_dim.extend([batchsize, 512, 1, 1])

    def decoder_branch(channels, output_blob):
      # Adds fc -> relu -> fc -> fold for one output; returns those layers.
      hidden = self._make_inner_product_layer(net, num_output=1024)
      hidden.bottom.append('hid')
      act = self._make_relu_layer(net)
      proj = self._make_inner_product_layer(net, num_output=channels*21*21)
      fold = self._make_folding_layer(net, channels, 21, 21)
      fold.top.append(output_blob)
      return [hidden, act, proj, fold]

    rgb_layers = decoder_branch(3, 'rgb')
    mask_layers = decoder_branch(1, 'mask')

    ids = self._connection_name_generator()
    for hidden, act, proj, fold in (rgb_layers, mask_layers):
      self._tie((hidden.name, (hidden.top, act.bottom, act.top, proj.bottom)), ids)
      self._tie((proj.name, (proj.top, fold.bottom)), ids)

    # Suffix every layer name with its creation index for uniqueness.
    for index, layer in enumerate(rgb_layers + mask_layers):
      layer.name += '_%d' % index

    return net

  def _build_shape_enc_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the 21x21 shape encoder.

    Each input 'ref', 'out', 'query' (batchsize x 3 x 21 x 21) passes
    through inner product (1024, shared 'ip1') -> ReLU -> inner product
    (512, shared 'ip2'), producing '<input>_hid'.

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'shape_enc'

    inputs = ('ref', 'out', 'query')
    for blob in inputs:
      net.input.append(blob)
      net.input_dim.extend([batchsize, 3, 21, 21])

    ids = self._connection_name_generator()

    encoders = []  # one (first ip, relu, second ip) triple per input
    for blob in inputs:
      first = self._make_inner_product_layer(net, num_output=1024, shared_name='ip1')
      first.bottom.append(blob)
      act = self._make_relu_layer(net)
      second = self._make_inner_product_layer(net, num_output=512, shared_name='ip2')
      second.top.append('%s_hid' % blob)
      encoders.append((first, act, second))

    for first, act, second in encoders:
      self._tie((first.name, (first.top, act.bottom, act.top, second.bottom)), ids)

    # Name suffixes are numbered grouped by layer kind: all first inner
    # products, then all ReLUs, then all second inner products.
    by_kind = [layer for column in zip(*encoders) for layer in column]
    for index, layer in enumerate(by_kind):
      layer.name += '_%d' % index

    return net

  def _build_sprite_dec_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the fully connected 60x60 sprite decoder.

    Input 'hid' (batchsize x 1024 x 1 x 1) feeds two independent branches:
    inner products of width 1024, 2048 and 4096 (each followed by ReLU),
    then a final inner product and a folding layer. One branch emits 'rgb'
    (3 x 60 x 60), the other 'mask' (1 x 60 x 60).

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'sprites_dec'
    net.input.append('hid')
    net.input_dim.extend([batchsize, 1024, 1, 1])

    def add_branch(channels, output_blob):
      # Creates one decoder branch; returns (layers, connections) in order.
      branch = []
      for width in (1024, 2048, 4096):
        branch.append(self._make_inner_product_layer(net, num_output=width))
        branch.append(self._make_relu_layer(net))
      branch[0].bottom.append('hid')
      last_fc = self._make_inner_product_layer(net, num_output=60*60*channels)
      fold = self._make_folding_layer(net, channels, 60, 60)
      fold.top.append(output_blob)
      branch += [last_fc, fold]

      links = []
      # branch[0..5] are (fc, relu) pairs; each pair feeds the next fc.
      for k in (0, 2, 4):
        fc, act, nxt = branch[k], branch[k + 1], branch[k + 2]
        links.append((fc.name, (fc.top, act.bottom, act.top, nxt.bottom)))
      links.append((last_fc.name, (last_fc.top, fold.bottom)))
      return branch, links

    rgb_layers, rgb_links = add_branch(3, 'rgb')
    mask_layers, mask_links = add_branch(1, 'mask')

    ids = self._connection_name_generator()
    for link in rgb_links + mask_links:
      self._tie(link, ids)

    # Suffix every layer name with its creation index for uniqueness.
    for index, layer in enumerate(rgb_layers + mask_layers):
      layer.name += '_%d' % index

    return net

  def _build_sprite_conv_enc2_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional 60x60 sprite encoder (variant 2).

    Each input 'ref', 'out', 'query' (batchsize x 3 x 60 x 60) passes
    through a weight-shared tower:

      5x5/2 conv 48 (60 -> 30) -> ReLU -> 5x5/2 conv 24 (30 -> 15) -> ReLU
      -> ip 2048 -> ReLU -> ip 1024 -> ReLU -> ip 1024 => '<input>_hid'

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'sprite_enc'

    inputs = ('ref', 'out', 'query')
    for blob in inputs:
      net.input.append(blob)
      net.input_dim.extend([batchsize, 3, 60, 60])

    ids = self._connection_name_generator()

    columns = []  # one 9-layer tower per input, in creation order
    for blob in inputs:
      conv_a = self._make_conv_layer(net, kernel_size=5, num_output=48, stride=2, pad=2,
                                     bias_value=0.1, shared_name='conv1')  # 60 -> 30
      conv_a.bottom.append(blob)
      act_a = self._make_relu_layer(net)
      conv_b = self._make_conv_layer(net, kernel_size=5, num_output=24, stride=2, pad=2,
                                     bias_value=0.1, shared_name='conv2')  # 30 -> 15
      act_b = self._make_relu_layer(net)
      ip_c = self._make_inner_product_layer(net, num_output=2048, shared_name='ip3')
      act_c = self._make_relu_layer(net)
      ip_d = self._make_inner_product_layer(net, num_output=1024, shared_name='ip4')
      act_d = self._make_relu_layer(net)
      ip_e = self._make_inner_product_layer(net, num_output=1024, shared_name='ip5')
      ip_e.top.append('%s_hid' % blob)
      columns.append([conv_a, act_a, conv_b, act_b, ip_c, act_c, ip_d, act_d, ip_e])

    for column in columns:
      # Layers 0, 2, 4, 6 each feed a ReLU (next index) and then the layer
      # two positions later.
      for k in range(0, 8, 2):
        producer, act, consumer = column[k], column[k + 1], column[k + 2]
        self._tie((producer.name, (producer.top, act.bottom, act.top, consumer.bottom)), ids)

    # Name suffixes are numbered grouped by layer kind across inputs.
    by_kind = [layer for kind in zip(*columns) for layer in kind]
    for index, layer in enumerate(by_kind):
      layer.name += '_%d' % index

    return net

  def _build_sprite_conv_enc_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional 60x60 sprite encoder.

    Each input 'ref', 'out', 'query', 'target' (batchsize x 3 x 60 x 60)
    passes through a weight-shared tower:

      5x5/2 conv 64 (60 -> 30) -> ReLU -> 5x5/2 conv 32 (30 -> 15) -> ReLU
      -> ip 2048 -> ReLU -> ip 1024 => '<input>_hid'

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'sprite_enc'

    inputs = ['ref', 'out', 'query', 'target']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 3, 60, 60])

    name_generator = self._connection_name_generator()

    conv1 = []  # 3x60x60 -> 64x30x30
    relu1 = []
    conv2 = []  # 64x30x30 -> 32x15x15
    relu2 = []
    fc3 = []    # 15x15x32 -> 2048
    relu3 = []
    fc4 = []    # 2048 -> 1024 ('<input>_hid')

    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, num_output=64, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv1'))
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      conv2.append(self._make_conv_layer(network, kernel_size=5, num_output=32, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv2'))
      relu2.append(self._make_relu_layer(network))

      fc3.append(self._make_inner_product_layer(network, num_output=2048, shared_name='ip3'))
      relu3.append(self._make_relu_layer(network))

      fc4.append(self._make_inner_product_layer(network, num_output=1024, shared_name='ip4'))
      fc4[-1].top.append('%s_hid' % inp)

    # Name suffixes are numbered grouped by layer kind across inputs.
    layers = conv1 + relu1 + conv2 + relu2 + fc3 + relu3 + fc4

    connections = []
    for c1, r1, c2, r2, f3, r3, f4 in zip(conv1, relu1, conv2, relu2, fc3, relu3, fc4):
      connections.append((c1.name, (c1.top, r1.bottom, r1.top, c2.bottom)))
      connections.append((c2.name, (c2.top, r2.bottom, r2.top, f3.bottom)))
      connections.append((f3.name, (f3.top, r3.bottom, r3.top, f4.bottom)))

    # Make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Make layer names unique by appending each layer's position in `layers`.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_sprite_conv_dec2_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional 60x60 sprite decoder (variant 2).

    Input 'hid' (batchsize x 1024 x 1 x 1) feeds two independent branches
    with identical structure:

      ip 1024 -> ReLU -> ip 2048 -> ReLU -> ip 15*15*24 -> ReLU
      -> fold to 24x15x15 -> upsample x2 (30) -> 5x5 conv 24, pad 3 (32)
      -> ReLU -> upsample x2 (64) -> 5x5 conv, no pad (60)

    The RGB branch's last conv has 3 outputs and emits 'rgb'; the mask
    branch's has 1 output and emits 'mask'.

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'sprites_dec'
    net.input.append('hid')
    net.input_dim.extend([batchsize, 1024, 1, 1])

    def add_branch(out_channels, output_blob):
      # Creates one decoder branch; returns (layers, connections) in order.
      ip_a = self._make_inner_product_layer(net, num_output=1024)
      ip_a.bottom.append('hid')
      act_a = self._make_relu_layer(net)
      ip_b = self._make_inner_product_layer(net, num_output=2048)
      act_b = self._make_relu_layer(net)
      ip_c = self._make_inner_product_layer(net, num_output=15*15*24)
      act_c = self._make_relu_layer(net)
      fold = self._make_folding_layer(net, 24, 15, 15)
      up_a = self._make_upsampling_layer(net, stride=2)   # 15 -> 30
      conv_a = self._make_conv_layer(net, kernel_size=5, num_output=24, stride=1, pad=3,
                                     bias_value=0.1)     # 30 -> 32
      act_d = self._make_relu_layer(net)
      up_b = self._make_upsampling_layer(net, stride=2)   # 32 -> 64
      conv_b = self._make_conv_layer(net, kernel_size=5, num_output=out_channels, stride=1,
                                     pad=0, bias_value=0.1)  # 64 -> 60
      conv_b.top.append(output_blob)

      branch = [ip_a, act_a, ip_b, act_b, ip_c, act_c, fold, up_a, conv_a, act_d, up_b, conv_b]
      links = [
          (ip_a.name, (ip_a.top, act_a.bottom, act_a.top, ip_b.bottom)),
          (ip_b.name, (ip_b.top, act_b.bottom, act_b.top, ip_c.bottom)),
          (ip_c.name, (ip_c.top, act_c.bottom, act_c.top, fold.bottom)),
          (fold.name, (fold.top, up_a.bottom)),
          (up_a.name, (up_a.top, conv_a.bottom)),
          (conv_a.name, (conv_a.top, act_d.bottom, act_d.top, up_b.bottom)),
          (up_b.name, (up_b.top, conv_b.bottom)),
      ]
      return branch, links

    rgb_layers, rgb_links = add_branch(3, 'rgb')
    mask_layers, mask_links = add_branch(1, 'mask')

    ids = self._connection_name_generator()
    for link in rgb_links + mask_links:
      self._tie(link, ids)

    # Suffix every layer name with its creation index for uniqueness.
    for index, layer in enumerate(rgb_layers + mask_layers):
      layer.name += '_%d' % index

    return net

  def _build_sprite_imp_dec_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build a convolutional 60x60 sprite decoder for two hidden inputs.

    For each input blob 'hid1' and 'hid2' (batchsize x 1024 x 1 x 1), two
    independent branches are built (no weight sharing):

      RGB:  ip 2048 -> ReLU -> ip 15*15*32 -> ReLU -> fold 32x15x15
            -> upsample x2 (30) -> 5x5 conv 48, pad 3 (32) -> ReLU
            -> upsample x2 (64) -> 5x5 conv 3, no pad (60) => '<input>_rgb'
      Mask: the same with 24 channels after folding and in the first conv,
            and a single-channel final conv                 => '<input>_mask'

    Each branch reads its own input blob and writes a uniquely named output,
    since a blob may only be produced by one layer. Every created layer is
    connected and renamed.

    NOTE(review): the output names '<input>_rgb' / '<input>_mask' follow the
    encoders' '<input>_hid' convention; confirm against the consumers.

    `wtype` and `std` are accepted but not used.

    Returns:
      The populated NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'sprites_dec'

    inputs = ['hid1', 'hid2']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 1024, 1, 1])

    layers = []
    connections = []

    def add_branch(bottom, fold_channels, mid_channels, out_channels, top):
      # Builds one decoder branch and records its layers and connections.
      fc1 = self._make_inner_product_layer(network, num_output=2048)
      fc1.bottom.append(bottom)
      relu1 = self._make_relu_layer(network)
      fc2 = self._make_inner_product_layer(network, num_output=15*15*fold_channels)
      relu2 = self._make_relu_layer(network)
      folding = self._make_folding_layer(network, fold_channels, 15, 15)
      # 15 -> 30
      upsample1 = self._make_upsampling_layer(network, stride=2)
      # 30 -> (30 + 6 - 5 + 1 = 32)
      conv3 = self._make_conv_layer(network, kernel_size=5, num_output=mid_channels, stride=1,
                                    pad=3, bias_value=0.1)
      relu3 = self._make_relu_layer(network)
      # 32 -> 64
      upsample2 = self._make_upsampling_layer(network, stride=2)
      # 64 -> 60
      conv4 = self._make_conv_layer(network, kernel_size=5, num_output=out_channels, stride=1,
                                    pad=0, bias_value=0.1)
      conv4.top.append(top)

      layers.extend([fc1, relu1, fc2, relu2, folding, upsample1, conv3, relu3,
                     upsample2, conv4])
      connections.extend([
          (fc1.name, (fc1.top, relu1.bottom, relu1.top, fc2.bottom)),
          (fc2.name, (fc2.top, relu2.bottom, relu2.top, folding.bottom)),
          (folding.name, (folding.top, upsample1.bottom)),
          (upsample1.name, (upsample1.top, conv3.bottom)),
          (conv3.name, (conv3.top, relu3.bottom, relu3.top, upsample2.bottom)),
          (upsample2.name, (upsample2.top, conv4.bottom)),
      ])

    for inp in inputs:
      add_branch(inp, 32, 48, 3, '%s_rgb' % inp)
      add_branch(inp, 24, 24, 1, '%s_mask' % inp)

    name_generator = self._connection_name_generator()

    # Make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Make layer names unique by appending each layer's position in `layers`.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_sprite_conv_dec_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the 'sprites_dec' convolutional decoder net.

    A single input blob 'hid', declared as (batchsize, 1024, 1, 1), feeds two
    decoder branches with identical layout. The RGB branch ends in top blob
    'rgb' (3 maps) and the mask branch ends in top blob 'mask' (1 map).

    Each branch is:
      fc(2048) -> relu -> fc(15*15*C) -> relu -> fold to Cx15x15
      -> upsample x2 (15 -> 30) -> conv 5x5 pad 3 (30 -> 32) -> relu
      -> upsample x2 (32 -> 64) -> conv 5x5 pad 0 (64 -> 60)

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of the 'hid' input.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'sprites_dec'
    network.input.append('hid')
    network.input_dim.extend([batchsize, 1024, 1, 1])

    def make_branch(fold_channels, conv3_channels, out_channels, out_blob):
      """Create one decoder branch reading 'hid'; return its layers in order."""
      ip_a = self._make_inner_product_layer(network, num_output=2048)
      ip_a.bottom.append('hid')
      act_a = self._make_relu_layer(network)
      ip_b = self._make_inner_product_layer(network,
                                            num_output=15*15*fold_channels)
      act_b = self._make_relu_layer(network)
      fold = self._make_folding_layer(network, fold_channels, 15, 15)
      # 15 -> 30
      up_a = self._make_upsampling_layer(network, stride=2)
      # 30 -> (30 + 6 - 5 + 1 = 32)
      conv_a = self._make_conv_layer(network, kernel_size=5,
                                     num_output=conv3_channels, stride=1,
                                     pad=3, bias_value=0.1)
      act_c = self._make_relu_layer(network)
      # 32 -> 64
      up_b = self._make_upsampling_layer(network, stride=2)
      # 64 -> (64 - 5 + 1 = 60)
      conv_b = self._make_conv_layer(network, kernel_size=5,
                                     num_output=out_channels, stride=1, pad=0,
                                     bias_value=0.1)
      conv_b.top.append(out_blob)
      return (ip_a, act_a, ip_b, act_b, fold, up_a, conv_a, act_c, up_b, conv_b)

    def wire(branch):
      """Return the (layer name, blob chain) pairs that link one branch."""
      ip_a, act_a, ip_b, act_b, fold, up_a, conv_a, act_c, up_b, conv_b = branch
      return [
          (ip_a.name, (ip_a.top, act_a.bottom, act_a.top, ip_b.bottom)),
          (ip_b.name, (ip_b.top, act_b.bottom, act_b.top, fold.bottom)),
          (fold.name, (fold.top, up_a.bottom)),
          (up_a.name, (up_a.top, conv_a.bottom)),
          (conv_a.name, (conv_a.top, act_c.bottom, act_c.top, up_b.bottom)),
          (up_b.name, (up_b.top, conv_b.bottom)),
      ]

    # RGB branch first, then mask branch (creation order matters for naming).
    rgb_branch = make_branch(32, 48, 3, 'rgb')
    mask_branch = make_branch(24, 24, 1, 'mask')

    layers = list(rgb_branch) + list(mask_branch)
    name_generator = self._connection_name_generator()

    # All connections are collected (names captured) before any tie happens.
    all_connections = wire(rgb_branch) + wire(mask_branch)
    for connection in all_connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return network

  def _build_sprite_enc_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the 'sprite_enc' net: one MLP tower per input image.

    Inputs 'ref', 'out' and 'query' are each declared as
    (batchsize, 3, 60, 60) and passed through fc(4096) -> relu -> fc(1024).
    Every tower's inner-product layers use shared_name 'ip1' / 'ip2'.
    Presumably this makes the towers share weights; that depends on
    _make_inner_product_layer. The output for input X is top blob 'X_hid'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'sprite_enc'
    inputs = ['ref', 'out', 'query']
    for blob in inputs:
      network.input.append(blob)
      network.input_dim.extend([batchsize, 3, 60, 60])

    name_generator = self._connection_name_generator()

    fc1, relu1, fc2 = [], [], []
    for blob in inputs:
      hidden = self._make_inner_product_layer(network, num_output=4096, shared_name='ip1')
      hidden.bottom.append(blob)
      fc1.append(hidden)
      relu1.append(self._make_relu_layer(network))
      code = self._make_inner_product_layer(network, num_output=1024, shared_name='ip2')
      code.top.append('%s_hid'% blob)
      fc2.append(code)

    # Stage-major order: all fc1 layers, then all relus, then all fc2 layers.
    layers = fc1 + relu1 + fc2

    connections = [(ip_a.name, (ip_a.top, act.bottom, act.top, ip_b.bottom))
                   for ip_a, act, ip_b in zip(fc1, relu1, fc2)]

    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return network

  def _build_analogy_network(self, wtype='xavier', std=0.01, batchsize=20):
    """Build the 'analogy' net over flattened 21*21*3 inputs.

    Inputs 'ref', 'out' and 'query' are each declared as
    (batchsize, 21*21*3, 1, 1). Each goes through an encoder tower
    fc(512) -> relu -> fc(256) -> relu -> fc(128) -> relu -> fc(64), using
    shared_name 'ip1'..'ip4' on every tower (presumably tying weights).

    The three 64-d codes are connected, in the order ref, out, query, to a
    sum layer built with coeff=[-1, 1, 1]. This presumably computes
    out - ref + query; confirm against _make_sum_layer / _tie.

    The combined code is decoded by
    fc(128) -> relu -> fc(256) -> relu -> fc(512) -> relu -> fc(21*21*3)
    into the top blob 'prediction'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'analogy'
    inputs = ['ref', 'out', 'query']
    for blob in inputs:
      network.input.append(blob)
      network.input_dim.extend([batchsize, 21*21*3, 1, 1])

    name_generator = self._connection_name_generator()

    # Encoder towers: one list per stage, one entry per input.
    fc1, relu1, fc2, relu2, fc3, relu3, fc4 = [], [], [], [], [], [], []
    for blob in inputs:
      first = self._make_inner_product_layer(network, num_output=512, shared_name='ip1')
      first.bottom.append(blob)
      fc1.append(first)
      relu1.append(self._make_relu_layer(network))
      fc2.append(self._make_inner_product_layer(network, num_output=256, shared_name='ip2'))
      relu2.append(self._make_relu_layer(network))
      fc3.append(self._make_inner_product_layer(network, num_output=128, shared_name='ip3'))
      relu3.append(self._make_relu_layer(network))
      fc4.append(self._make_inner_product_layer(network, num_output=64, shared_name='ip4'))

    # Combine the three codes, then decode.
    fc4_sum = self._make_sum_layer(network, coeff=[-1,1,1])
    fc5 = self._make_inner_product_layer(network, num_output=128)
    relu5 = self._make_relu_layer(network)
    fc6 = self._make_inner_product_layer(network, num_output=256)
    relu6 = self._make_relu_layer(network)
    fc7 = self._make_inner_product_layer(network, num_output=512)
    relu7 = self._make_relu_layer(network)
    fc8 = self._make_inner_product_layer(network, num_output=21*21*3)
    fc8.top.append('prediction')

    layers = (fc1 + relu1 + fc2 + relu2 + fc3 + relu3 + fc4 +
              [fc4_sum, fc5, relu5, fc6, relu6, fc7, relu7, fc8])

    connections = []
    for tower in zip(fc1, relu1, fc2, relu2, fc3, relu3, fc4):
      ip_a, act_a, ip_b, act_b, ip_c, act_c, ip_d = tower
      connections.extend([
          (ip_a.name, (ip_a.top, act_a.bottom, act_a.top, ip_b.bottom)),
          (ip_b.name, (ip_b.top, act_b.bottom, act_b.top, ip_c.bottom)),
          (ip_c.name, (ip_c.top, act_c.bottom, act_c.top, ip_d.bottom)),
      ])
    # Codes enter the sum layer in input order (ref, out, query).
    connections.extend((code.name, (code.top, fc4_sum.bottom)) for code in fc4)
    connections.extend([
        (fc4_sum.name, (fc4_sum.top, fc5.bottom)),
        (fc5.name, (fc5.top, relu5.bottom, relu5.top, fc6.bottom)),
        (fc6.name, (fc6.top, relu6.bottom, relu6.top, fc7.bottom)),
        (fc7.name, (fc7.top, relu7.bottom, relu7.top, fc8.bottom)),
    ])

    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return network

  def _build_shape_enc_network3(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the fully connected 'shape_enc' net over 32x32 RGB inputs.

    Inputs 'ref', 'out', 'query' and 'target' are each declared as
    (batchsize, 3, 32, 32) and encoded by fc(2048) -> relu -> fc(512).
    Every tower uses shared_name 'ip1' / 'ip2' (presumably tying weights).
    The code for input X is the top blob 'X_hid'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'shape_enc'
    inputs = ['ref', 'out', 'query', 'target']
    for blob in inputs:
      network.input.append(blob)
      network.input_dim.extend([batchsize, 3, 32, 32])

    name_generator = self._connection_name_generator()

    # One (fc1, relu1, fc2) tuple per input, created in input order.
    towers = []
    for blob in inputs:
      ip_a = self._make_inner_product_layer(network, num_output=2048, shared_name='ip1')
      ip_a.bottom.append(blob)
      act = self._make_relu_layer(network)
      ip_b = self._make_inner_product_layer(network, num_output=512, shared_name='ip2')
      ip_b.top.append('%s_hid' % blob)
      towers.append((ip_a, act, ip_b))

    # Stage-major order: every fc1, then every relu1, then every fc2.
    layers = [layer for stage in zip(*towers) for layer in stage]

    connections = [(ip_a.name, (ip_a.top, act.bottom, act.top, ip_b.bottom))
                   for ip_a, act, ip_b in towers]

    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return network

  def _build_shape_enc_network2(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional 'shape_enc' net over 32x32 RGB inputs.

    Inputs 'ref', 'out', 'query' and 'target' are each declared as
    (batchsize, 3, 32, 32). Each is encoded by
      conv 5x5 stride 2 pad 2, 64 maps (32x32 -> 16x16) -> relu
      -> fc(2048) -> relu -> fc(512)
    and the code for input X is the top blob 'X_hid'.

    Every tower uses the same shared_name values ('conv1', 'ip3', 'ip4'),
    presumably so the towers share weights. NOTE: the fc stages are named
    'ip3'/'ip4' rather than 'ip2'/'ip3'. They are left unchanged because
    renaming them could break weight sharing or the loading of existing
    weights.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'shape_enc'
    inputs = ['ref', 'out', 'query', 'target']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 3, 32, 32])

    name_generator = self._connection_name_generator()

    # Conv output size follows floor((in + 2*pad - kernel) / stride) + 1.
    conv1 = []  # 3x32x32 -> 64x16x16
    relu1 = []
    fc2 = []    # 16*16*64 -> 2048
    relu2 = []
    fc3 = []    # 2048 -> 512

    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, num_output=64, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv1'))
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      fc2.append(self._make_inner_product_layer(network, num_output=2048, shared_name='ip3'))
      relu2.append(self._make_relu_layer(network))

      fc3.append(self._make_inner_product_layer(network, num_output=512, shared_name='ip4'))
      fc3[-1].top.append('%s_hid' % inp)

    layers = conv1 + relu1 + fc2 + relu2 + fc3

    # Wire each tower in input order. zip keeps this in step with `inputs`
    # instead of relying on a separately hard-coded tower count.
    connections = []
    for conv, act1, ip2, act2, ip3 in zip(conv1, relu1, fc2, relu2, fc3):
      connections.append((conv.name, (conv.top, act1.bottom, act1.top, ip2.bottom)))
      connections.append((ip2.name, (ip2.top, act2.bottom, act2.top, ip3.bottom)))

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network
    
  def _build_3d_enc_small_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the fully connected '3d_enc' encoder over 64x64 RGB inputs.

    Inputs 'ref', 'out', 'query' and 'target' are each declared as
    (batchsize, 3, 64, 64) and encoded by
    fc(2048) -> relu -> fc(1024) -> relu -> fc(512).
    Every tower uses shared_name 'ip1'/'ip2'/'ip3' (presumably tying
    weights). The code for input X is the top blob 'X_hid'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = '3d_enc'
    inputs = ['ref', 'out', 'query', 'target']
    for blob in inputs:
      network.input.append(blob)
      network.input_dim.extend([batchsize, 3, 64, 64])

    name_generator = self._connection_name_generator()

    # One (fc1, relu1, fc2, relu2, fc3) tuple per input, created in input order.
    towers = []
    for blob in inputs:
      ip_a = self._make_inner_product_layer(network, num_output=2048, shared_name='ip1')
      ip_a.bottom.append(blob)
      act_a = self._make_relu_layer(network)
      ip_b = self._make_inner_product_layer(network, num_output=1024, shared_name='ip2')
      act_b = self._make_relu_layer(network)
      ip_c = self._make_inner_product_layer(network, num_output=512, shared_name='ip3')
      ip_c.top.append('%s_hid' % blob)
      towers.append((ip_a, act_a, ip_b, act_b, ip_c))

    # Stage-major order: every fc1, then every relu1, ..., then every fc3.
    layers = [layer for stage in zip(*towers) for layer in stage]

    connections = []
    for ip_a, act_a, ip_b, act_b, ip_c in towers:
      connections.append((ip_a.name, (ip_a.top, act_a.bottom, act_a.top, ip_b.bottom)))
      connections.append((ip_b.name, (ip_b.top, act_b.bottom, act_b.top, ip_c.bottom)))

    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return network

  def _build_3d_enc_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional '3d_enc' encoder over 64x64 RGB inputs.

    Inputs 'ref', 'out', 'query' and 'target' are each declared as
    (batchsize, 3, 64, 64). Each is encoded by
      conv 5x5 stride 2 pad 2, 64 maps (64x64 -> 32x32) -> relu
      conv 5x5 stride 2 pad 2, 32 maps (32x32 -> 16x16) -> relu
      conv 3x3 stride 1 pad 1, 32 maps (16x16 -> 16x16) -> relu
      fc(2048) -> relu -> fc(1024)
    and the code for input X is the top blob 'X_hid'.

    All towers use the same shared_name per stage ('conv1'..'conv3',
    'ip4', 'ip5'), presumably so they share weights.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = '3d_enc'
    inputs = ['ref', 'out', 'query', 'target']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 3, 64, 64])

    name_generator = self._connection_name_generator()

    # Conv output size follows floor((in + 2*pad - kernel) / stride) + 1.
    conv1 = []  # 3x64x64  -> 64x32x32
    relu1 = []
    conv2 = []  # 64x32x32 -> 32x16x16
    relu2 = []
    conv3 = []  # 32x16x16 -> 32x16x16
    relu3 = []
    fc4 = []    # 16*16*32 -> 2048
    relu4 = []
    fc5 = []    # 2048 -> 1024

    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, num_output=64, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv1'))
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      conv2.append(self._make_conv_layer(network, kernel_size=5, num_output=32, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv2'))
      relu2.append(self._make_relu_layer(network))

      conv3.append(self._make_conv_layer(network, kernel_size=3, num_output=32, stride=1, pad=1,
                                         bias_value=0.1, shared_name='conv3'))
      relu3.append(self._make_relu_layer(network))

      fc4.append(self._make_inner_product_layer(network, num_output=2048, shared_name='ip4'))
      relu4.append(self._make_relu_layer(network))

      fc5.append(self._make_inner_product_layer(network, num_output=1024, shared_name='ip5'))
      fc5[-1].top.append('%s_hid' % inp)

    layers = conv1 + relu1 + conv2 + relu2 + conv3 + relu3 + fc4 + relu4 + fc5

    # Wire each tower in input order. zip keeps this in step with `inputs`
    # instead of relying on a separately hard-coded tower count.
    connections = []
    for c1, r1, c2, r2, c3, r3, f4, r4, f5 in zip(conv1, relu1, conv2, relu2, conv3,
                                                 relu3, fc4, relu4, fc5):
      connections.append((c1.name, (c1.top, r1.bottom, r1.top, c2.bottom)))
      connections.append((c2.name, (c2.top, r2.bottom, r2.top, c3.bottom)))
      connections.append((c3.name, (c3.top, r3.bottom, r3.top, f4.bottom)))
      connections.append((f4.name, (f4.top, r4.bottom, r4.top, f5.bottom)))

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_3d_enc_network2(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional '3d_enc' encoder with a 300-d code.

    Same as the 1024-d variant except for the final layer width. Inputs
    'ref', 'out', 'query' and 'target' are each declared as
    (batchsize, 3, 64, 64). Each is encoded by
      conv 5x5 stride 2 pad 2, 64 maps (64x64 -> 32x32) -> relu
      conv 5x5 stride 2 pad 2, 32 maps (32x32 -> 16x16) -> relu
      conv 3x3 stride 1 pad 1, 32 maps (16x16 -> 16x16) -> relu
      fc(2048) -> relu -> fc(300)
    and the code for input X is the top blob 'X_hid'.

    All towers use the same shared_name per stage ('conv1'..'conv3',
    'ip4', 'ip5'), presumably so they share weights.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of every input blob.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = '3d_enc'
    inputs = ['ref', 'out', 'query', 'target']
    for inp in inputs:
      network.input.append(inp)
      network.input_dim.extend([batchsize, 3, 64, 64])

    name_generator = self._connection_name_generator()

    # Conv output size follows floor((in + 2*pad - kernel) / stride) + 1.
    conv1 = []  # 3x64x64  -> 64x32x32
    relu1 = []
    conv2 = []  # 64x32x32 -> 32x16x16
    relu2 = []
    conv3 = []  # 32x16x16 -> 32x16x16
    relu3 = []
    fc4 = []    # 16*16*32 -> 2048
    relu4 = []
    fc5 = []    # 2048 -> 300

    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, num_output=64, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv1'))
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      conv2.append(self._make_conv_layer(network, kernel_size=5, num_output=32, stride=2, pad=2,
                                         bias_value=0.1, shared_name='conv2'))
      relu2.append(self._make_relu_layer(network))

      conv3.append(self._make_conv_layer(network, kernel_size=3, num_output=32, stride=1, pad=1,
                                         bias_value=0.1, shared_name='conv3'))
      relu3.append(self._make_relu_layer(network))

      fc4.append(self._make_inner_product_layer(network, num_output=2048, shared_name='ip4'))
      relu4.append(self._make_relu_layer(network))

      fc5.append(self._make_inner_product_layer(network, num_output=300, shared_name='ip5'))
      fc5[-1].top.append('%s_hid' % inp)

    layers = conv1 + relu1 + conv2 + relu2 + conv3 + relu3 + fc4 + relu4 + fc5

    # Wire each tower in input order. zip keeps this in step with `inputs`
    # instead of relying on a separately hard-coded tower count.
    connections = []
    for c1, r1, c2, r2, c3, r3, f4, r4, f5 in zip(conv1, relu1, conv2, relu2, conv3,
                                                 relu3, fc4, relu4, fc5):
      connections.append((c1.name, (c1.top, r1.bottom, r1.top, c2.bottom)))
      connections.append((c2.name, (c2.top, r2.bottom, r2.top, c3.bottom)))
      connections.append((c3.name, (c3.top, r3.bottom, r3.top, f4.bottom)))
      connections.append((f4.name, (f4.top, r4.bottom, r4.top, f5.bottom)))

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_shape_dec_network3(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the fully connected 'shapes_dec' decoder.

    Inputs 'trans' and 'hid' are each declared as (batchsize, 512, 1, 1).

    An increment is computed as inc = fc512(prod(fc256(trans), fc256(hid))).
    The prod layer presumably multiplies element-wise; see _make_prod_layer.
    The increment is combined with 'hid' by a sum layer, and the result
    passes through a split layer into the RGB decoder
      fc(2048) -> relu -> fc(32*32*3) -> fold to 3x32x32
    whose output is the top blob 'rgb'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of both input blobs.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'shapes_dec'
    for blob in ('trans', 'hid'):
      network.input.append(blob)
      network.input_dim.extend([batchsize, 512, 1, 1])

    # Compute transformation increment.
    t1 = self._make_inner_product_layer(network, num_output=256)
    t1.bottom.append('trans')

    h1 = self._make_inner_product_layer(network, num_output=256)
    h1.bottom.append('hid')

    fac = self._make_prod_layer(network)
    inc = self._make_inner_product_layer(network, num_output=512)

    # Combine the increment with the incoming 'hid'. Named sum_layer (not
    # `sum`) so the builtin is not shadowed. 'hid' is attached here; the
    # increment is connected by the wiring below.
    sum_layer = self._make_sum_layer(network)
    sum_layer.bottom.append('hid')

    split = self._make_split_layer(network)

    # RGB decoder.
    fc1 = self._make_inner_product_layer(network, num_output=2048)
    relu1 = self._make_relu_layer(network)
    fc2 = self._make_inner_product_layer(network, num_output=32*32*3)
    folding = self._make_folding_layer(network, 3, 32, 32)
    folding.top.append('rgb')

    layers = [t1, h1, fac, inc, sum_layer, split, fc1, relu1, fc2, folding]
    name_generator = self._connection_name_generator()

    connections = [
        (t1.name, (t1.top, fac.bottom)),
        (h1.name, (h1.top, fac.bottom)),
        (fac.name, (fac.top, inc.bottom)),
        (inc.name, (inc.top, sum_layer.bottom)),
        (sum_layer.name, (sum_layer.top, split.bottom)),
        (split.name, (split.top, fc1.bottom)),
        (fc1.name, (fc1.top, relu1.bottom, relu1.top, fc2.bottom)),
        (fc2.name, (fc2.top, folding.bottom)),
    ]

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_shape_dec_network2(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the convolutional 'shapes_dec' decoder.

    Inputs 'trans' and 'hid' are each declared as (batchsize, 512, 1, 1).

    An increment is computed as inc = fc512(prod(fc256(trans), fc256(hid))).
    The prod layer presumably multiplies element-wise; see _make_prod_layer.
    The increment is combined with 'hid' by a sum layer, and the result
    passes through a split layer into the RGB decoder
      fc(2048) -> relu -> fc(16*16*24) -> relu -> fold to 24x16x16
      -> upsample x2 (16 -> 32) -> conv 5x5 pad 2, 3 maps (32 -> 32)
    whose output is the top blob 'rgb'.

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of both input blobs.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'shapes_dec'
    for blob in ('trans', 'hid'):
      network.input.append(blob)
      network.input_dim.extend([batchsize, 512, 1, 1])

    # Compute transformation increment.
    t1 = self._make_inner_product_layer(network, num_output=256)
    t1.bottom.append('trans')

    h1 = self._make_inner_product_layer(network, num_output=256)
    h1.bottom.append('hid')

    fac = self._make_prod_layer(network)
    inc = self._make_inner_product_layer(network, num_output=512)

    # Combine the increment with the incoming 'hid'. Named sum_layer (not
    # `sum`) so the builtin is not shadowed. 'hid' is attached here; the
    # increment is connected by the wiring below.
    sum_layer = self._make_sum_layer(network)
    sum_layer.bottom.append('hid')

    split = self._make_split_layer(network)

    # RGB decoder
    fc1 = self._make_inner_product_layer(network, num_output=2048)
    relu1 = self._make_relu_layer(network)
    fc2 = self._make_inner_product_layer(network, num_output=16*16*24)
    relu2 = self._make_relu_layer(network)
    folding = self._make_folding_layer(network, 24, 16, 16)
    # 16 -> 32
    upsample1 = self._make_upsampling_layer(network, stride=2)
    # 32 -> (32 + 4 - 5 + 1 = 32)
    conv3 = self._make_conv_layer(network, kernel_size=5, num_output=3, stride=1, pad=2,
                                  bias_value=0.1)
    conv3.top.append('rgb')

    layers = [t1, h1, fac, inc, sum_layer, split, fc1, relu1, fc2, relu2, folding,
              upsample1, conv3]
    name_generator = self._connection_name_generator()

    connections = [
        (t1.name, (t1.top, fac.bottom)),
        (h1.name, (h1.top, fac.bottom)),
        (fac.name, (fac.top, inc.bottom)),
        (inc.name, (inc.top, sum_layer.bottom)),
        (sum_layer.name, (sum_layer.top, split.bottom)),
        (split.name, (split.top, fc1.bottom)),
        (fc1.name, (fc1.top, relu1.bottom, relu1.top, fc2.bottom)),
        (fc2.name, (fc2.top, relu2.bottom, relu2.top, folding.bottom)),
        (folding.name, (folding.top, upsample1.bottom)),
        (upsample1.name, (upsample1.top, conv3.bottom)),
    ]

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_3d_dec_small_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the fully connected 'dec_small' decoder with RGB and mask heads.

    Inputs 'trans' and 'hid' are each declared as (batchsize, 1024, 1, 1).

    The increment is inc = fc1024(relu(fc512(trans) (+) fc512(hid))). Here
    (+) is the layer built by _make_sum_layer, presumably an element-wise
    sum. The increment is combined with 'hid' by another sum layer.

    The result goes into a split layer whose explicit top 'hid_new' exposes
    the updated state. The split also feeds two decoders:
      RGB:  fc(1024) -> relu -> fc(2048) -> relu -> fc(64*64*3)
            -> fold to 3x64x64  -> top 'rgb'
      mask: fc(1024) -> relu -> fc(2048) -> relu -> fc(64*64*1)
            -> fold to 1x64x64  -> top 'mask'

    Args:
      wtype: unused by this builder.
      std: unused by this builder.
      batchsize: leading dimension of both input blobs.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'dec_small'
    for blob in ('trans', 'hid'):
      network.input.append(blob)
      network.input_dim.extend([batchsize, 1024, 1, 1])

    # Compute transformation increment.
    t1 = self._make_inner_product_layer(network, num_output=512)
    t1.bottom.append('trans')

    h1 = self._make_inner_product_layer(network, num_output=512)
    h1.bottom.append('hid')

    # Unlike the shape decoders, the two projections are combined with a sum
    # layer (not a prod layer) followed by a relu.
    fac = self._make_sum_layer(network)
    fac_relu = self._make_relu_layer(network)
    inc = self._make_inner_product_layer(network, num_output=1024)

    # Combine the increment with the incoming 'hid'. Named sum_layer (not
    # `sum`) so the builtin is not shadowed. 'hid' is attached here; the
    # increment is connected by the wiring below.
    sum_layer = self._make_sum_layer(network)
    sum_layer.bottom.append('hid')

    split = self._make_split_layer(network)
    split.top.append('hid_new')

    # RGB decoder
    fc1 = self._make_inner_product_layer(network, num_output=1024)
    relu1 = self._make_relu_layer(network)
    fc2 = self._make_inner_product_layer(network, num_output=2048)
    relu2 = self._make_relu_layer(network)
    fc3 = self._make_inner_product_layer(network, num_output=64*64*3)
    folding = self._make_folding_layer(network, 3, 64, 64)
    folding.top.append('rgb')

    # Mask decoder
    fc1_m = self._make_inner_product_layer(network, num_output=1024)
    relu1_m = self._make_relu_layer(network)
    fc2_m = self._make_inner_product_layer(network, num_output=2048)
    relu2_m = self._make_relu_layer(network)
    fc3_m = self._make_inner_product_layer(network, num_output=64*64*1)
    folding_m = self._make_folding_layer(network, 1, 64, 64)
    folding_m.top.append('mask')

    layers = [t1, h1, fac, fac_relu, inc, sum_layer, split,
              fc1, relu1, fc2, relu2, fc3, folding,
              fc1_m, relu1_m, fc2_m, relu2_m, fc3_m, folding_m]
    name_generator = self._connection_name_generator()

    connections = [
        (t1.name, (t1.top, fac.bottom)),
        (h1.name, (h1.top, fac.bottom)),
        (fac.name, (fac.top, fac_relu.bottom, fac_relu.top, inc.bottom)),
        (inc.name, (inc.top, sum_layer.bottom)),
        (sum_layer.name, (sum_layer.top, split.bottom)),

        (split.name, (split.top, fc1.bottom)),
        (fc1.name, (fc1.top, relu1.bottom, relu1.top, fc2.bottom)),
        (fc2.name, (fc2.top, relu2.bottom, relu2.top, fc3.bottom)),
        (fc3.name, (fc3.top, folding.bottom)),

        # The split feeds the mask decoder as well.
        (split.name, (split.top, fc1_m.bottom)),
        (fc1_m.name, (fc1_m.top, relu1_m.bottom, relu1_m.top, fc2_m.bottom)),
        (fc2_m.name, (fc2_m.top, relu2_m.bottom, relu2_m.top, fc3_m.bottom)),
        (fc3_m.name, (fc3_m.top, folding_m.bottom)),
    ]

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Append each layer's position in `layers` to its name.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_3d_dec_network2(self, wtype='xavier', std=0.01, batchsize=25):
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'dec_3d'
    network.input.append('trans')
    network.input_dim.append(batchsize)
    network.input_dim.append(300)
    network.input_dim.append(1)
    network.input_dim.append(1)
    network.input.append('hid')
    network.input_dim.append(batchsize)
    network.input_dim.append(300)
    network.input_dim.append(1)
    network.input_dim.append(1)
    network.input.append('switch')
    network.input_dim.append(batchsize)
    network.input_dim.append(300)
    network.input_dim.append(1)
    network.input_dim.append(1)

    # Compute transformation increment.
    t1 = self._make_inner_product_layer(network, num_output=100);
    t1.bottom.append('trans')

    h1 = self._make_inner_product_layer(network, num_output=100);
    h1.bottom.append('hid')

    fac = self._make_sum_layer(network)
    fac_relu = self._make_relu_layer(network)
    
    # Add FC layers to predict increment better.
    fh1 = self._make_inner_product_layer(network, num_output=100)
    fh1_relu = self._make_relu_layer(network)

    fh2 = self._make_inner_product_layer(network, num_output=100)
    fh2_relu = self._make_relu_layer(network)

    inc = self._make_inner_product_layer(network, num_output=300)
    prod = self._make_prod_layer(network)
    prod.bottom.append('switch')

    sum = self._make_sum_layer(network)
    sum.bottom.append('hid')

    split = self._make_split_layer(network)
    split.top.append('hid_new')

    # RGB decoder
    fc1 = self._make_inner_product_layer(network, num_output=1024)
    relu1 = self._make_relu_layer(network)
    fc2 = self._make_inner_product_layer(network, num_output=2048)
    relu2 = self._make_relu_layer(network)
    fc3 = self._make_inner_product_layer(network, num_output=16*16*24)
    relu3 = self._make_relu_layer(network)
    folding = self._make_folding_layer(network, 24, 16, 16)
    # 16 -> 32
    upsample1 = self._make_upsampling_layer(network, stride=2)
    # 32 -> (32 + 4 - 3 + 1 = 34)
    conv4 = self._make_conv_layer(network, kernel_size=3, num_output=64, stride=1, pad=2,
                                  bias_value=0.1)
    relu4 = self._make_relu_layer(network)
    # 34
    conv5 = self._make_conv_layer(network, kernel_size=3, num_output=48, stride=1, pad=1,
                                  bias_value=0.1)
    relu5 = self._make_relu_layer(network)
    # 34 -> 68
    upsample2 = self._make_upsampling_layer(network, stride=2)
    # 68-> 66
    conv6 = self._make_conv_layer(network, kernel_size=3, num_output=24, stride=1, pad=0,
                                  bias_value=0.1)
    relu6 = self._make_relu_layer(network)
    # 66 -> 64
    conv7 = self._make_conv_layer(network, kernel_size=3, num_output=3, stride=1, pad=0,
                                  bias_value=0.1)
    conv7.top.append('rgb')

    # Mask decoder
    fc1_m = self._make_inner_product_layer(network, num_output=1024)
    relu1_m = self._make_relu_layer(network)
    fc2_m = self._make_inner_product_layer(network, num_output=2048)
    relu2_m = self._make_relu_layer(network)
    fc3_m = self._make_inner_product_layer(network, num_output=16*16*24)
    relu3_m = self._make_relu_layer(network)
    folding_m = self._make_folding_layer(network, 24, 16, 16)
    # 16 -> 32
    upsample1_m = self._make_upsampling_layer(network, stride=2)
    # 32 -> (32 + 4 - 3 + 1 = 34)
    conv4_m = self._make_conv_layer(network, kernel_size=3, num_output=64, stride=1, pad=2,
                                  bias_value=0.1)
    relu4_m = self._make_relu_layer(network)
    # 34
    conv5_m = self._make_conv_layer(network, kernel_size=3, num_output=48, stride=1, pad=1,
                                  bias_value=0.1)
    relu5_m = self._make_relu_layer(network)
    # 34 -> 68
    upsample2_m = self._make_upsampling_layer(network, stride=2)
    # 68 -> 66
    conv6_m = self._make_conv_layer(network, kernel_size=3, num_output=24, stride=1, pad=0,
                                  bias_value=0.1)
    relu6_m = self._make_relu_layer(network)
    # 66 -> 64
    conv7_m = self._make_conv_layer(network, kernel_size=3, num_output=1, stride=1, pad=0,
                                  bias_value=0.1)
    conv7_m.top.append('mask')


    layers = [ t1, h1, fac, fac_relu, fh1, fh1_relu, fh2, fh2_relu, inc, prod, sum, split, fc1, relu1, fc2, relu2, fc3, relu3, folding, upsample1, conv4, relu4, conv5, relu5, upsample2, conv6, relu6, conv7, fc1_m, relu1_m, fc2_m, relu2_m, fc3_m, relu3_m, folding_m, upsample1_m, conv4_m, relu4_m, conv5_m, relu5_m, upsample2_m, conv6_m, relu6_m, conv7_m ]
    name_generator = self._connection_name_generator()

    connections = []
    connections.append((t1.name, (t1.top, fac.bottom)))
    connections.append((h1.name, (h1.top, fac.bottom)))
    connections.append((fac.name, (fac.top, fac_relu.bottom, fac_relu.top, fh1.bottom)))
    connections.append((fh1.name, (fh1.top, fh1_relu.bottom, fh1_relu.top, fh2.bottom)))
    connections.append((fh2.name, (fh2.top, fh2_relu.bottom, fh2_relu.top, inc.bottom)))
    connections.append((inc.name, (inc.top, prod.bottom)))
    connections.append((prod.name, (prod.top, sum.bottom)))
    connections.append((sum.name, (sum.top, split.bottom)))

    connections.append((split.name, (split.top, fc1.bottom)))
    connections.append((fc1.name, (fc1.top, relu1.bottom, relu1.top, fc2.bottom)))
    connections.append((fc2.name, (fc2.top, relu2.bottom, relu2.top, fc3.bottom)))
    connections.append((fc3.name, (fc3.top, relu3.bottom, relu3.top, folding.bottom)))
    connections.append((folding.name, (folding.top, upsample1.bottom)))
    connections.append((upsample1.name, (upsample1.top, conv4.bottom)))
    connections.append((conv4.name, (conv4.top, relu4.bottom, relu4.top, conv5.bottom)))
    connections.append((conv5.name, (conv5.top, relu5.bottom, relu5.top, upsample2.bottom)))
    connections.append((upsample2.name, (upsample2.top, conv6.bottom)))
    connections.append((conv6.name, (conv6.top, relu6.bottom, relu6.top, conv7.bottom)))

    connections.append((split.name, (split.top, fc1_m.bottom)))
    connections.append((fc1_m.name, (fc1_m.top, relu1_m.bottom, relu1_m.top, fc2_m.bottom)))
    connections.append((fc2_m.name, (fc2_m.top, relu2_m.bottom, relu2_m.top, fc3_m.bottom)))
    connections.append((fc3_m.name, (fc3_m.top, relu3_m.bottom, relu3_m.top, folding_m.bottom)))
    connections.append((folding_m.name, (folding_m.top, upsample1_m.bottom)))
    connections.append((upsample1_m.name, (upsample1_m.top, conv4_m.bottom)))
    connections.append((conv4_m.name, (conv4_m.top, relu4_m.bottom, relu4_m.top, conv5_m.bottom)))
    connections.append((conv5_m.name, (conv5_m.top, relu5_m.bottom, relu5_m.top, upsample2_m.bottom)))
    connections.append((upsample2_m.name, (upsample2_m.top, conv6_m.bottom)))
    connections.append((conv6_m.name, (conv6_m.top, relu6_m.bottom, relu6_m.top, conv7_m.bottom)))

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Fix up the names based on the connections that were generated.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network

  def _build_3d_dec_network(self, wtype='xavier', std=0.01, batchsize=25):
    """Build the 'sprites_dec' net: two parallel decoder branches fed by 'hid'.

    The net declares one input blob, 'hid', with input_dim
    (batchsize, 1024, 1, 1). Two branches read 'hid', and both use the same
    layer sequence:

      inner product (2048) -> relu -> inner product (16*16*24) -> relu ->
      folding layer (24, 16, 16) -> upsampling (stride 2) ->
      5x5 conv (pad 3) -> relu -> upsampling (stride 2) -> 5x5 conv (pad 0)

    The first branch has 48 channels in its middle conv. It ends with a
    3-output conv whose top is 'rgb'. The second branch has 24 middle
    channels and ends with a 1-output conv whose top is 'mask'.

    Args:
      wtype: not used by this builder.
      std: not used by this builder.
      batchsize: first input_dim entry declared for the 'hid' input.

    Returns:
      The populated caffe_pb2.NetParameter.
    """
    net = caffe_pb2.NetParameter()
    net.force_backward = True
    net.name = 'sprites_dec'
    net.input.append('hid')
    for dim in (batchsize, 1024, 1, 1):
      net.input_dim.append(dim)

    def add_branch(mid_channels, out_channels, out_blob):
      """Create one decoder branch on `net`; return its layers in creation order."""
      first_fc = self._make_inner_product_layer(net, num_output=2048)
      first_fc.bottom.append('hid')
      # List elements are evaluated left to right, so layers are added to
      # `net` in exactly this order.
      branch = [first_fc,
                self._make_relu_layer(net),
                self._make_inner_product_layer(net, num_output=16*16*24),
                self._make_relu_layer(net),
                self._make_folding_layer(net, 24, 16, 16),
                # 16 -> 32
                self._make_upsampling_layer(net, stride=2),
                # 32 -> (32 + 6 - 5 + 1 = 34)
                self._make_conv_layer(net, kernel_size=5, num_output=mid_channels,
                                      stride=1, pad=3, bias_value=0.1),
                self._make_relu_layer(net),
                # 34 -> 68
                self._make_upsampling_layer(net, stride=2)]
      # 68 -> 64
      last_conv = self._make_conv_layer(net, kernel_size=5, num_output=out_channels,
                                        stride=1, pad=0, bias_value=0.1)
      last_conv.top.append(out_blob)
      branch.append(last_conv)
      return branch

    def branch_connections(branch):
      """Return the (layer name, fields to tie) pairs wiring one branch together."""
      (fc_a, relu_a, fc_b, relu_b, fold, up_a,
       conv_a, relu_c, up_b, conv_b) = branch
      return [
          (fc_a.name, (fc_a.top, relu_a.bottom, relu_a.top, fc_b.bottom)),
          (fc_b.name, (fc_b.top, relu_b.bottom, relu_b.top, fold.bottom)),
          (fold.name, (fold.top, up_a.bottom)),
          (up_a.name, (up_a.top, conv_a.bottom)),
          (conv_a.name, (conv_a.top, relu_c.bottom, relu_c.top, up_b.bottom)),
          (up_b.name, (up_b.top, conv_b.bottom)),
      ]

    # Create every layer first (RGB branch, then mask branch).
    rgb_branch = add_branch(48, 3, 'rgb')
    mask_branch = add_branch(24, 1, 'mask')
    layers = rgb_branch + mask_branch
    name_generator = self._connection_name_generator()

    # Collect all connections before tying any of them.
    connections = branch_connections(rgb_branch) + branch_connections(mask_branch)
    for connection in connections:
      self._tie(connection, name_generator)

    # Suffix each layer's name with its index in `layers`.
    for index, layer in enumerate(layers):
      layer.name += '_%d' % index

    return net



  def build_network(self, netname, batchsize=20):
    """Build the named network, write it as a prototxt, and load it in Caffe.

    Looks up the builder method registered for `netname` and calls it with
    `batchsize`. The resulting NetParameter is printed, then written in
    protobuf text format to '<netname>.prototxt' in the current working
    directory. Finally a caffe `Net` is constructed from that file.

    Args:
      netname: string key selecting which `_build_*` method to call.
      batchsize: forwarded to the builder as its `batchsize` argument.

    Returns:
      The `Net` loaded from the written prototxt. Returns None, after
      printing a message, when `netname` is not recognised.
    """
    # Map each public net name to the name of its builder method. The names
    # are resolved lazily with getattr, so only the requested builder has to
    # exist (the same as the original if/elif chain).
    builders = {
        'analogy': '_build_analogy_network',
        'analogy_small': '_build_small_analogy_network',
        'analogy_sprite_conv_enc': '_build_sprite_conv_enc_network',
        'analogy_sprite_conv_dec': '_build_sprite_conv_dec_network',
        'analogy_sprite_imp_enc': '_build_sprite_imp_enc_network',
        'analogy_sprite_imp_dec': '_build_sprite_imp_dec_network',
        'analogy_sprite_conv_enc2': '_build_sprite_conv_enc2_network',
        'analogy_sprite_conv_dec2': '_build_sprite_conv_dec2_network',
        'analogy_sprite_enc': '_build_sprite_enc_network',
        'analogy_sprite_dec': '_build_sprite_dec_network',
        'analogy_shape_enc': '_build_shape_enc_network',
        'analogy_shape_dec': '_build_shape_dec_network',
        'analogy_3d_enc': '_build_3d_enc_network',
        'analogy_3d_enc_small': '_build_3d_enc_small_network',
        'analogy_3d_dec': '_build_3d_dec_network',
        'analogy_3d_dec_small': '_build_3d_dec_small_network',
        'analogy_3d_dec2': '_build_3d_dec_network2',
        'analogy_3d_enc2': '_build_3d_enc_network2',
        'analogy_shape_enc2': '_build_shape_enc_network2',
        'analogy_shape_dec2': '_build_shape_dec_network2',
        'analogy_shape_enc3': '_build_shape_enc_network3',
        'analogy_shape_dec3': '_build_shape_dec_network3',
    }

    builder_name = builders.get(netname)
    if builder_name is None:
      print('unknown netname: %s' % netname)
      return None
    network = getattr(self, builder_name)(batchsize=batchsize)

    network_filename = '%s.prototxt' % netname
    # A single parenthesised argument prints the same under Python 2's print
    # statement and Python 3's print function. The bare `print network`
    # form was a SyntaxError on Python 3.
    print(network)
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)


if __name__ == '__main__':
  # When run as a script, generate one network definition. To produce a
  # different net, change `netname` below to any name accepted by
  # NetworkBuilder.build_network (for example 'analogy_3d_enc2' or
  # 'analogy_shape_dec3'). The previously commented-out invocations used
  # batchsize=20 for 'analogy', batchsize=100 for 'analogy_small' and
  # batchsize=25 for the others.
  builder = NetworkBuilder()
  builder.build_network(netname='analogy_3d_dec2', batchsize=25)

