# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods

from caffe.proto import caffe_pb2
from caffe import Net

from google.protobuf import text_format
import sys

# pylint: disable=invalid-name
# pylint: disable=no-member
# Short aliases for the protobuf enum wrappers used when filling in layer
# parameters below; members are looked up by name, e.g.
# LayerType.Value('RELU').
LayerType = caffe_pb2.LayerParameter.LayerType
# Operation selector for ELTWISE layers ('PROD' and 'SUM' are used below).
EltwiseOp = caffe_pb2.EltwiseParameter.EltwiseOp
# Pooling method selector for POOLING layers ('MAX' and 'AVE' are used below).
PoolMethod = caffe_pb2.PoolingParameter.PoolMethod
# NOTE(review): DBType is not referenced in the visible part of this file.
DBType = caffe_pb2.DataParameter.DB
# pylint: enable=invalid-name
# pylint: enable=no-member

# Module-level default batch size.
# NOTE(review): the builder methods visible in this file take their own
# `batchsize` parameter, which shadows this global inside them.
batchsize = 1

class NetworkBuilder(object):

  def __init__(self, training_batch_size=20, testing_batch_size=20, **kwargs):
    """Store the batch sizes and any extra keyword options on the builder.

    Args:
      training_batch_size: kept as `self.training_batch_size`.
      testing_batch_size: kept as `self.testing_batch_size`.
      **kwargs: any further keyword arguments, kept as the dict
        `self.other_args`.
    """
    self.other_args = kwargs
    self.testing_batch_size = testing_batch_size
    self.training_batch_size = training_batch_size


  def _make_inception(self, network, x1x1, x3x3r, x3x3, x5x5r, x5x5, proj,
                      name_generator):
    """Add one Inception submodule to `network` and wire its internal blobs.

    Four branches hang off a single split layer and meet again in one concat
    layer:
      1. 1x1 conv with `x1x1` outputs
      2. 1x1 conv with `x3x3r` outputs, then 3x3 conv (pad 1) with `x3x3`
      3. 1x1 conv with `x5x5r` outputs, then 5x5 conv (pad 2) with `x5x5`
      4. 3x3 max pool, then 1x1 conv (pad 1) with `proj` outputs
    Each conv and the pool is followed by a ReLU whose bottom and top receive
    the same blob name as the producing layer's top.

    Args:
      network: message whose `layers` repeated field receives the new layers.
      x1x1, x3x3r, x3x3, x5x5r, x5x5, proj: `num_output` of the convolutions
        listed above.
      name_generator: iterator of unique ids consumed by `_tie`.

    Returns:
      List of the created layers in creation order (split first, concat
      last).  The split's bottom and the concat's top are not connected here.
    """

    layers = []

    def register(layer):
      # Record the layer in creation order and hand it back to the caller.
      layers.append(layer)
      return layer

    def conv(**settings):
      return register(self._make_conv_layer(network, **settings))

    def relu():
      return register(self._make_relu_layer(network))

    def through_relu(producer, activation, consumer):
      # producer.top, the ReLU's bottom/top and consumer.bottom share a name.
      return (producer.name,
              (producer.top, activation.bottom, activation.top,
               consumer.bottom))

    # Layers are created in exactly this order; it determines their position
    # in network.layers.
    split = register(self._make_split_layer(network))

    context1 = conv(kernel_size=1, num_output=x1x1, bias_value=0)
    relu1 = relu()

    context2a = conv(kernel_size=1, num_output=x3x3r, bias_value=0)
    relu2a = relu()
    context2b = conv(kernel_size=3, num_output=x3x3, pad=1)
    relu2b = relu()

    context3a = conv(kernel_size=1, num_output=x5x5r, bias_value=0)
    relu3a = relu()
    context3b = conv(kernel_size=5, num_output=x5x5, pad=2)
    relu3b = relu()

    context4a = register(self._make_maxpool_layer(network, kernel_size=3))
    relu4a = relu()
    context4b = conv(kernel_size=1, num_output=proj, pad=1, bias_value=0)
    relu4b = relu()

    concat = register(self._make_concat_layer(network))

    # Fan out from the split to the head of every branch.
    connections = [(split.name, (split.top, head.bottom))
                   for head in (context1, context2a, context3a, context4a)]
    # Inside the two-stage branches: first stage -> ReLU -> second stage.
    connections.extend([
        through_relu(context2a, relu2a, context2b),
        through_relu(context3a, relu3a, context3b),
        through_relu(context4a, relu4a, context4b),
    ])
    # Tail of every branch -> ReLU -> concat.
    connections.extend([
        through_relu(context1, relu1, concat),
        through_relu(context2b, relu2b, concat),
        through_relu(context3b, relu3b, concat),
        through_relu(context4b, relu4b, concat),
    ])

    for link in connections:
      self._tie(link, name_generator)

    return layers

  def _make_prod_layer(self, network, coeff=None):
    """Append an ELTWISE layer named 'prod' that uses the PROD operation.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      coeff: optional iterable of coefficients copied into
        `eltwise_param.coeff`; nothing is copied when it is None or empty.

    Returns:
      The newly added layer message.
    """
    prod = network.layers.add()
    prod.name = 'prod'
    prod.type = LayerType.Value('ELTWISE')
    eltwise = prod.eltwise_param
    eltwise.operation = EltwiseOp.Value('PROD')
    if coeff:
      eltwise.coeff.extend(coeff)
    return prod

  def _make_sum_layer(self, network, coeff=None):
    """Append an ELTWISE layer named 'sum' that uses the SUM operation.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      coeff: optional iterable of coefficients copied into
        `eltwise_param.coeff`; nothing is copied when it is None or empty.

    Returns:
      The newly added layer message.
    """
    total = network.layers.add()
    total.name = 'sum'
    total.type = LayerType.Value('ELTWISE')
    eltwise = total.eltwise_param
    eltwise.operation = EltwiseOp.Value('SUM')
    if coeff:
      eltwise.coeff.extend(coeff)
    return total

  def _make_upsampling_layer(self, network, stride):
    """Append an UPSAMPLING layer named 'upsample'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      stride: value written to `upsampling_param.kernel_size`.

    Returns:
      The newly added layer message.
    """
    upsample = network.layers.add()
    upsample.name = 'upsample'
    upsample.type = LayerType.Value('UPSAMPLING')
    # The caller-facing name is `stride`, but the proto field is kernel_size.
    upsample.upsampling_param.kernel_size = stride
    return upsample

  def _make_folding_layer(self, network, channels, height, width, prefix=''):
    """Append a FOLDING layer named '<prefix>folding'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      channels: value for `folding_param.channels_folded`.
      height: value for `folding_param.height_folded`.
      width: value for `folding_param.width_folded`.
      prefix: text placed in front of 'folding' in the layer name.

    Returns:
      The newly added layer message.
    """
    fold = network.layers.add()
    fold.name = '%sfolding' % (prefix)
    fold.type = LayerType.Value('FOLDING')
    folded = fold.folding_param
    folded.channels_folded = channels
    folded.height_folded = height
    folded.width_folded = width
    return fold

  def _make_conv_layer(self, network, kernel_size, num_output, stride=1, pad=0,
                       bias_value=0.1, shared_name=None, wtype='xavier', std=0.01):
    """Append a CONVOLUTION layer named 'conv_<k>x<k>_<stride>'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      kernel_size: value for `convolution_param.kernel_size`.
      num_output: value for `convolution_param.num_output`.
      stride: value for `convolution_param.stride`.
      pad: value for `convolution_param.pad`.
      bias_value: value of the 'constant' bias filler.
      shared_name: when truthy, '<shared_name>_w' and '<shared_name>_b' are
        appended to the layer's `param` field.
      wtype: weight filler type.
      std: weight filler std; only written when the filler type is
        'gaussian' (together with mean 0).

    Returns:
      The newly added layer message.
    """

    conv = network.layers.add()
    conv.name = 'conv_%dx%d_%d' % (kernel_size, kernel_size, stride)
    conv.type = LayerType.Value('CONVOLUTION')

    conv_param = conv.convolution_param
    conv_param.num_output = num_output
    conv_param.kernel_size = kernel_size
    conv_param.stride = stride
    conv_param.pad = pad

    weights = conv_param.weight_filler
    weights.type = wtype
    if weights.type == 'gaussian':
      weights.mean = 0
      weights.std = std

    bias = conv_param.bias_filler
    bias.type = 'constant'
    bias.value = bias_value

    # Two entries each; presumably weights first, bias second, as in the
    # inner-product builders of this class -- TODO confirm.
    conv.blobs_lr.extend([1, 2])
    conv.weight_decay.extend([1, 0])

    if shared_name:
      conv.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return conv

  def _make_maxpool_layer(self, network, kernel_size, stride=1):
    """Append a MAX POOLING layer named 'maxpool_<k>x<k>_<stride>'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      kernel_size: value for `pooling_param.kernel_size`.
      stride: value for `pooling_param.stride`.

    Returns:
      The newly added layer message.
    """

    pool = network.layers.add()
    pool.name = 'maxpool_%dx%d_%d' % (kernel_size, kernel_size, stride)
    pool.type = LayerType.Value('POOLING')

    pooling = pool.pooling_param
    pooling.pool = PoolMethod.Value('MAX')
    pooling.kernel_size = kernel_size
    pooling.stride = stride
    return pool

  def _make_avgpool_layer(self, network, kernel_size, stride=1):
    """Append an AVE POOLING layer named 'avgpool_<k>x<k>_<stride>'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      kernel_size: value for `pooling_param.kernel_size`.
      stride: value for `pooling_param.stride`.

    Returns:
      The newly added layer message.
    """

    pool = network.layers.add()
    pool.name = 'avgpool_%dx%d_%d' % (kernel_size, kernel_size, stride)
    pool.type = LayerType.Value('POOLING')

    pooling = pool.pooling_param
    pooling.pool = PoolMethod.Value('AVE')
    pooling.kernel_size = kernel_size
    pooling.stride = stride
    return pool

  def _make_lrn_layer(self, network, name='lrn'):
    """Append an LRN layer with fixed settings.

    The settings written are local_size=5, alpha=0.0001 and beta=0.75.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      name: layer name.

    Returns:
      The newly added layer message.
    """

    lrn = network.layers.add()
    lrn.name = name
    lrn.type = LayerType.Value('LRN')

    norm = lrn.lrn_param
    norm.local_size = 5
    norm.alpha = 0.0001
    norm.beta = 0.75
    return lrn

  def _make_concat_layer(self, network, concat_dim=1):
    """Append a CONCAT layer named 'concat'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      concat_dim: value for `concat_param.concat_dim`.

    Returns:
      The newly added layer message.
    """

    joined = network.layers.add()
    joined.name = 'concat'
    joined.type = LayerType.Value('CONCAT')
    joined.concat_param.concat_dim = concat_dim
    return joined

  def _make_dropout_layer(self, network, dropout_ratio=0.5):
    """Append a DROPOUT layer named 'dropout'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      dropout_ratio: value for `dropout_param.dropout_ratio`.

    Returns:
      The newly added layer message.
    """

    dropout = network.layers.add()
    dropout.name = 'dropout'
    dropout.type = LayerType.Value('DROPOUT')
    dropout.dropout_param.dropout_ratio = dropout_ratio
    return dropout

  def _make_tensor_layer(self, network, num_output, weight_lr=1,
                         bias_lr=2, bias_value=0.1, prefix='',
                         shared_name=None,
                         wtype='xavier', std=0.01):
    """Append a TENSOR_PRODUCT layer named '<prefix>tensor_product'.

    The settings are written to the layer's `inner_product_param`.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      num_output: value for `inner_product_param.num_output`.
      weight_lr: first `blobs_lr` entry.
      bias_lr: second `blobs_lr` entry.
      bias_value: value of the 'constant' bias filler.
      prefix: text placed in front of 'tensor_product' in the layer name.
      shared_name: when truthy, '<shared_name>_w' and '<shared_name>_b' are
        appended to the layer's `param` field.
      wtype: weight filler type.
      std: weight filler std; only written when `wtype` is 'gaussian'
        (together with mean 0).

    Returns:
      The newly added layer message.
    """

    tensor = network.layers.add()
    tensor.name = '%stensor_product' % prefix
    tensor.type = LayerType.Value('TENSOR_PRODUCT')

    product = tensor.inner_product_param
    product.num_output = num_output

    weights = product.weight_filler
    weights.type = wtype
    if wtype == 'gaussian':
      weights.mean = 0
      weights.std = std

    bias = product.bias_filler
    bias.type = 'constant'
    bias.value = bias_value

    tensor.blobs_lr.extend([weight_lr, bias_lr])
    tensor.weight_decay.extend([1, 0])

    if shared_name:
      tensor.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return tensor

  def _make_inner_product_layer(self, network, num_output, weight_lr=1,
                                bias_lr=2, bias_value=0.1, prefix='',
                                shared_name=None,
                                wtype='xavier', std=0.01):
    """Append an INNER_PRODUCT layer named '<prefix>inner_product'.

    Args:
      network: message whose `layers` repeated field receives the new layer.
      num_output: value for `inner_product_param.num_output`.
      weight_lr: first `blobs_lr` entry.
      bias_lr: second `blobs_lr` entry.
      bias_value: value of the 'constant' bias filler.
      prefix: text placed in front of 'inner_product' in the layer name.
      shared_name: when truthy, '<shared_name>_w' and '<shared_name>_b' are
        appended to the layer's `param` field.
      wtype: weight filler type.
      std: weight filler std; only written when `wtype` is 'gaussian'
        (together with mean 0).

    Returns:
      The newly added layer message.
    """

    dense = network.layers.add()
    dense.name = '%sinner_product' % prefix
    dense.type = LayerType.Value('INNER_PRODUCT')

    product = dense.inner_product_param
    product.num_output = num_output

    weights = product.weight_filler
    weights.type = wtype
    if wtype == 'gaussian':
      weights.mean = 0
      weights.std = std

    bias = product.bias_filler
    bias.type = 'constant'
    bias.value = bias_value

    dense.blobs_lr.extend([weight_lr, bias_lr])
    dense.weight_decay.extend([1, 0])

    if shared_name:
      dense.param.extend(['%s_w' % shared_name, '%s_b' % shared_name])

    return dense

  def _make_split_layer(self, network):
    """Append a SPLIT layer named 'split' to `network.layers` and return it."""

    splitter = network.layers.add()
    splitter.name = 'split'
    splitter.type = LayerType.Value('SPLIT')
    return splitter

  def _make_relu_layer(self, network):
    """Append a RELU layer named 'relu' to `network.layers` and return it."""

    activation = network.layers.add()
    activation.name = 'relu'
    activation.type = LayerType.Value('RELU')
    return activation

  def _tie(self, layers, name_generator):
    """Generate a named connection between layer endpoints."""

    name = 'ep_%s_%d' % (layers[0], name_generator.next())
    for layer in layers[1]:
      layer.append(name)

  def _connection_name_generator(self):
    """Yield 0, 1, 2, ... without end; each value serves as a unique id."""

    next_id = -1
    while True:
      next_id += 1
      yield next_id


  def _build_rnn_network_2layer(self, wtype='xavier', std=0.01, batchsize=1,
                                numstep=4):
    """Build the 'analogy_rnn' net definition with a two-conv-layer encoder.

    Five inputs ('ref', 'out', 'query', 'target', 'diffid') pass through
    identically configured encoder stacks whose layers register the same
    `shared_name` parameter names.  The view code of 'query' is then updated
    `numstep` times by an increment computed from the current view and the
    transformed (out - ref) view difference.  A decoder with an image path
    and a mask path is created after every step, plus one extra decoder fed
    by `fc1_diff_id` and the initial query view.

    Args:
      wtype: not used in this method body.
      std: not used in this method body.
      batchsize: not used in this method body; the batch dimensions of the
        inputs are hard-coded below.
      numstep: number of recurrent update steps.  Must be at least 1: with 0
        the `t==0` branch indexes the still-empty `view_split` list.

    Returns:
      The populated `caffe_pb2.NetParameter`.
    """
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'analogy_rnn'
    # Declare the five net inputs.  Each adds four input_dim entries
    # (1, 3, 64, 64), except 'target' whose first entry is numstep.
    network.input.append('ref')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('out')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('query')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('target')
    network.input_dim.append(numstep)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('diffid')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)

    layers = []
    name_generator = self._connection_name_generator()

    # One list per layer role.  Encoder lists hold one entry per input;
    # action and decoder lists hold one entry per time step.
    conv1 = []
    relu1 = []
    conv2 = []
    relu2 = []
    fc1 = []
    relu4 = []
    fc2 = []
    relu5 = []
    relu5_split = []
    fc1_view = []
    view_split = []
    inc_fc2 = []
    inc_relu2 = []
    inc_fc3 = []
    sum_view = []
    sum_trans = []
    # NOTE(review): relu1_view is never populated in this method.
    relu1_view = []
    relu1_trans = []
    sum_view_split = []
    concat = []
    concat_split = []
    dec_fc1 = []
    dec_relu1 = []
    dec_fc2 = []
    dec_relu2 = []
    dec_relu2_split = []
    dec_img_fc1 = []
    dec_img_relu1 = []
    dec_img_fold = []
    dec_img_up1 = []
    dec_img_conv1 = []
    dec_img_relu2 = []
    dec_img_up2 = []
    dec_img_conv2 = []
    dec_mask_fc1 = []
    dec_mask_relu1 = []
    dec_mask_fold = []
    dec_mask_up1 = []
    dec_mask_conv1 = []
    dec_mask_relu2 = []
    dec_mask_up2 = []
    dec_mask_conv2 = []

    # Encode all five inputs with the same layer stack (the shared_name
    # arguments give every copy the same parameter names).
    inputs = [ 'ref', 'out', 'query', 'target', 'diffid' ]
    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, stride=2, pad=2, num_output=64, shared_name='conv1'))
      # The first conv reads directly from the net input of the same name.
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      conv2.append(self._make_conv_layer(network, kernel_size=5, stride=2, pad=2, num_output=72, shared_name='conv2'))
      relu2.append(self._make_relu_layer(network))

      fc1.append(self._make_inner_product_layer(network, num_output=1024, shared_name='fc1'))
      relu4.append(self._make_relu_layer(network))

      fc2.append(self._make_inner_product_layer(network, num_output=1024, shared_name='fc2'))
      relu5.append(self._make_relu_layer(network))
      relu5_split.append(self._make_split_layer(network))
      relu5_split[-1].name = '%s_relu5_split' % inp

      # 'query', 'diffid' and 'target' additionally get a 512-output head
      # that registers the parameter names of 'fc1_id'.
      if inp=='query':
        fc1_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')
        id_split = self._make_split_layer(network)
        id_split.name = 'query_id_split'
      elif inp=='diffid':
        fc1_diff_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')
        fc1_diff_id.name = 'fc1_diff_id'
        continue # don't care about view from this input.
      elif inp=='target':
        target_fc1_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')

      # View head.  Because 'diffid' is last and skipped above, fc1_view is
      # indexed like the first four entries of `inputs`.
      fc1_view.append(self._make_inner_product_layer(network, num_output=256, shared_name='fc1_view'))
      fc1_view[-1].name = '%s_fc1_view' % inp

    # Extract target hidden units.
    target_hid = self._make_concat_layer(network)
    target_hid.top.append('target_hid')

    # Extract trans = out_view - ref_view
    # (coeff -1 pairs with the first bottom tied below, ref; +1 with out).
    trans_view = self._make_sum_layer(network, coeff=[-1,1])
    trans_view_fc1 = self._make_inner_product_layer(network, num_output=256)
    trans_view_fc1_split = self._make_split_layer(network)
    trans_view_fc1_split.name = 'trans_view_fc1_split'

    # Make the encoding connections.
    # Each entry is (source layer name, endpoint lists) and is turned into a
    # shared blob name by _tie at the end.  The layer name is captured here,
    # i.e. before the position suffix is appended at the bottom.
    connections = []
    for idx,inp in enumerate(inputs):
      connections.append((conv1[idx].name, (conv1[idx].top, relu1[idx].bottom, relu1[idx].top, conv2[idx].bottom)))
      connections.append((conv2[idx].name, (conv2[idx].top, relu2[idx].bottom, relu2[idx].top, fc1[idx].bottom)))
      connections.append((fc1[idx].name, (fc1[idx].top, relu4[idx].bottom, relu4[idx].top, fc2[idx].bottom)))
      connections.append((fc2[idx].name, (fc2[idx].top, relu5[idx].bottom)))
      connections.append((relu5[idx].name, (relu5[idx].top, relu5_split[idx].bottom)))

      if inp != 'diffid': # diffid has no view output.
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_view[idx].bottom)))
      if inp=='query':
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_id.bottom)))
        connections.append((fc1_id.name, (fc1_id.top, id_split.bottom)))
      elif inp=='diffid':
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_diff_id.bottom)))
      elif inp=='target':
        # Both target heads feed the 'target_hid' concat.
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, target_fc1_id.bottom)))
        connections.append((target_fc1_id.name, (target_fc1_id.top, target_hid.bottom)))
        connections.append((fc1_view[idx].name, (fc1_view[idx].top, target_hid.bottom)))


    # Connect view transformation.
    # fc1_view[0] is the 'ref' view and fc1_view[1] the 'out' view.
    connections.append((fc1_view[0].name, (fc1_view[0].top, trans_view.bottom)))
    connections.append((fc1_view[1].name, (fc1_view[1].top, trans_view.bottom)))
    connections.append((trans_view.name, (trans_view.top, trans_view_fc1.bottom)))
    connections.append((trans_view_fc1.name, (trans_view_fc1.top, trans_view_fc1_split.bottom)))
    # Iterations 0 .. numstep-1 create an action block and a decoder each;
    # the extra iteration t == numstep creates only the diff-id decoder.
    for t in range(numstep+1):
      if t < numstep:
        # Action.
        view_split.append(self._make_split_layer(network))
        sum_trans.append(self._make_sum_layer(network))
        relu1_trans.append(self._make_relu_layer(network))
        inc_fc2.append(self._make_inner_product_layer(network, num_output=256, shared_name='inc_fc2'))
        inc_relu2.append(self._make_relu_layer(network))
        inc_fc3.append(self._make_inner_product_layer(network, num_output=256, shared_name='inc_fc3'))
        sum_view.append(self._make_sum_layer(network))
        sum_view_split.append(self._make_split_layer(network))

      ###### Action connections ######
      # current view -> view_split.
      # Go to both transformation increment and add to next step.
      # NOTE(review): with numstep == 0 view_split is empty here and the
      # t==0 branch raises IndexError.
      if t==0: # Take view from query encoder.
        connections.append((fc1_view[2].name, (fc1_view[2].top, view_split[-1].bottom)))
      elif t<numstep:    # Take view from previous time step.
        connections.append((sum_view_split[-2].name, (sum_view_split[-2].top, view_split[-1].bottom)))

      if t<numstep:
        # increment = inc_fc3(relu(inc_fc2(relu(view + trans))));
        # next view = view + increment.
        connections.append((view_split[-1].name, (view_split[-1].top, sum_trans[-1].bottom)))
        connections.append((trans_view_fc1_split.name, (trans_view_fc1_split.top, sum_trans[-1].bottom)))
        connections.append((sum_trans[-1].name, (sum_trans[-1].top, relu1_trans[-1].bottom, relu1_trans[-1].top, inc_fc2[-1].bottom)))
        connections.append((inc_fc2[-1].name, (inc_fc2[-1].top, inc_relu2[-1].bottom, inc_relu2[-1].top, inc_fc3[-1].bottom)))
        connections.append((inc_fc3[-1].name, (inc_fc3[-1].top, sum_view[-1].bottom)))
        connections.append((view_split[-1].name, (view_split[-1].top, sum_view[-1].bottom)))
        connections.append((sum_view[-1].name, (sum_view[-1].top, sum_view_split[-1].bottom)))

      # Decoder.
      concat.append(self._make_concat_layer(network))
      concat_split.append(self._make_split_layer(network))
      dec_fc1.append(self._make_inner_product_layer(network, num_output=1024, shared_name='dec_fc1'))
      dec_relu1.append(self._make_relu_layer(network))
      dec_fc2.append(self._make_inner_product_layer(network, num_output=1024, shared_name='dec_fc2'))
      dec_relu2.append(self._make_relu_layer(network))
      dec_relu2_split.append(self._make_split_layer(network))
      # Dec img path.
      # 12288 = 48 * 16 * 16, the dimensions passed to the folding layer.
      dec_img_fc1.append(self._make_inner_product_layer(network, num_output=12288, shared_name='dec_img_fc1'))
      dec_img_relu1.append(self._make_relu_layer(network))
      dec_img_fold.append(self._make_folding_layer(network,48,16,16))
      dec_img_up1.append(self._make_upsampling_layer(network,stride=2))
      dec_img_conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=72, shared_name='dec_img_conv1'))
      dec_img_relu2.append(self._make_relu_layer(network))
      dec_img_up2.append(self._make_upsampling_layer(network,stride=2))
      dec_img_conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=3, shared_name='dec_img_conv2'))
      dec_img_conv2[-1].name = 'dec_img_conv2_t%d' % t
      # Dec mask path.
      # 8192 = 32 * 16 * 16, the dimensions passed to the folding layer.
      dec_mask_fc1.append(self._make_inner_product_layer(network, num_output=8192, shared_name='dec_mask_fc1'))
      dec_mask_relu1.append(self._make_relu_layer(network))
      dec_mask_fold.append(self._make_folding_layer(network,32,16,16))
      dec_mask_up1.append(self._make_upsampling_layer(network,stride=2))
      dec_mask_conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=64, shared_name='dec_mask_conv1'))
      dec_mask_relu2.append(self._make_relu_layer(network))
      dec_mask_up2.append(self._make_upsampling_layer(network,stride=2))
      dec_mask_conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=1, shared_name='dec_mask_conv2'))
      dec_mask_conv2[-1].name = 'dec_mask_conv2_t%d' % t

      # dec connections.
      if t==numstep:
        # Diff ID decoder.
        # Combines fc1_diff_id with the view entering step 0 (view_split[0]).
        connections.append((fc1_diff_id.name, (fc1_diff_id.top, concat[-1].bottom)))
        connections.append((view_split[0].name, (view_split[0].top, concat[-1].bottom)))
      else:
        # Normal decoder.
        # Combines the query id head with the view produced by step t.
        connections.append((id_split.name, (id_split.top, concat[-1].bottom)))
        connections.append((sum_view_split[-1].name, (sum_view_split[-1].top, concat[-1].bottom)))
      connections.append((concat[-1].name, (concat[-1].top, concat_split[-1].bottom)))
      connections.append((concat_split[-1].name, (concat_split[-1].top, dec_fc1[-1].bottom)))
      connections.append((dec_fc1[-1].name, (dec_fc1[-1].top, dec_relu1[-1].bottom, dec_relu1[-1].top, dec_fc2[-1].bottom)))
      connections.append((dec_fc2[-1].name, (dec_fc2[-1].top, dec_relu2[-1].bottom)))
      connections.append((dec_relu2[-1].name, (dec_relu2[-1].top, dec_relu2_split[-1].bottom)))
      # dec image connections.
      connections.append((dec_relu2_split[-1].name, (dec_relu2_split[-1].top, dec_img_fc1[-1].bottom)))
      connections.append((dec_img_fc1[-1].name, (dec_img_fc1[-1].top, dec_img_relu1[-1].bottom, dec_img_relu1[-1].top, dec_img_fold[-1].bottom)))
      connections.append((dec_img_fold[-1].name, (dec_img_fold[-1].top, dec_img_up1[-1].bottom)))
      connections.append((dec_img_up1[-1].name, (dec_img_up1[-1].top, dec_img_conv1[-1].bottom)))
      connections.append((dec_img_conv1[-1].name, (dec_img_conv1[-1].top, dec_img_relu2[-1].bottom, dec_img_relu2[-1].top, dec_img_up2[-1].bottom)))
      connections.append((dec_img_up2[-1].name, (dec_img_up2[-1].top, dec_img_conv2[-1].bottom)))
      # dec mask connections.
      connections.append((dec_relu2_split[-1].name, (dec_relu2_split[-1].top, dec_mask_fc1[-1].bottom)))
      connections.append((dec_mask_fc1[-1].name, (dec_mask_fc1[-1].top, dec_mask_relu1[-1].bottom, dec_mask_relu1[-1].top, dec_mask_fold[-1].bottom)))
      connections.append((dec_mask_fold[-1].name, (dec_mask_fold[-1].top, dec_mask_up1[-1].bottom)))
      connections.append((dec_mask_up1[-1].name, (dec_mask_up1[-1].top, dec_mask_conv1[-1].bottom)))
      connections.append((dec_mask_conv1[-1].name, (dec_mask_conv1[-1].top, dec_mask_relu2[-1].bottom, dec_mask_relu2[-1].top, dec_mask_up2[-1].bottom)))
      connections.append((dec_mask_up2[-1].name, (dec_mask_up2[-1].top, dec_mask_conv2[-1].bottom)))

    # Collect every created layer; the list is only used further down to
    # append a position suffix to each layer name.
    layers = [ fc1_id, fc1_diff_id, id_split, trans_view, trans_view_fc1, trans_view_fc1_split, target_hid, target_fc1_id ]
    layers += conv1
    layers += conv2
    layers += relu1
    layers += relu2
    layers += fc1
    layers += relu4
    layers += fc2
    layers += relu5
    layers += relu5_split
    layers += fc1_view
    layers += view_split
    layers += relu1_view
    layers += inc_fc2
    layers += inc_relu2
    layers += inc_fc3
    layers += sum_view
    layers += sum_trans
    layers += relu1_trans
    layers += concat
    layers += concat_split
    layers += sum_view_split
    layers += dec_fc1
    layers += dec_relu1
    layers += dec_fc2
    layers += dec_relu2
    layers += dec_relu2_split
    layers += dec_img_fc1
    layers += dec_img_relu1
    layers += dec_img_fold
    layers += dec_img_up1
    layers += dec_img_conv1
    layers += dec_img_relu2
    layers += dec_img_up2
    layers += dec_img_conv2
    layers += dec_mask_fc1
    layers += dec_mask_relu1
    layers += dec_mask_fold
    layers += dec_mask_up1
    layers += dec_mask_conv1
    layers += dec_mask_relu2
    layers += dec_mask_up2
    layers += dec_mask_conv2

    # Final concat.
    # Joins the image / mask outputs of all numstep + 1 decoders with
    # concat_dim=0 into the blobs 'dec_img_cat' and 'dec_mask_cat'.
    dec_img_cat = self._make_concat_layer(network, concat_dim=0)
    dec_img_cat.top.append('dec_img_cat')
    dec_mask_cat = self._make_concat_layer(network, concat_dim=0)
    dec_mask_cat.top.append('dec_mask_cat')
    for l in dec_img_conv2:
      connections.append((l.name, (l.top, dec_img_cat.bottom)))
    for l in dec_mask_conv2:
      connections.append((l.name, (l.top, dec_mask_cat.bottom)))
    layers.append(dec_img_cat)
    layers.append(dec_mask_cat)

    # Final hidden output.
    # With more than one step, the concat splits of all regular decoders
    # (every entry except the last, diff-id one) are renamed and joined into
    # 'hid_out'; otherwise 'hid_out' is added as an extra top of the first
    # concat split.
    if numstep > 1:
      hid_out = self._make_concat_layer(network, concat_dim=0)
      hid_out.top.append('hid_out')
      layers.append(hid_out)
      for idx,l in enumerate(concat_split[:-1]):
        l.name = 'hid_out_t%d' % idx
        connections.append((l.name, (l.top, hid_out.bottom)))
    else:
      concat_split[0].top.append('hid_out')

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Make the layer names unique by appending each layer's position in
    # `layers`.  The blob names created by _tie above keep the un-suffixed
    # layer names that were captured when the connections were recorded.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network


  def _build_rnn_network(self, wtype='xavier', std=0.01, batchsize=1, numstep=4):
    network = caffe_pb2.NetParameter()
    network.force_backward = True
    network.name = 'analogy_rnn'
    network.input.append('ref')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('out')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('query')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('target')
    network.input_dim.append(numstep)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)
    network.input.append('diffid')
    network.input_dim.append(1)
    network.input_dim.append(3)
    network.input_dim.append(64)
    network.input_dim.append(64)

    layers = []
    name_generator = self._connection_name_generator()

    conv1 = []
    relu1 = []
    conv2 = []
    relu2 = []
    conv3 = []
    relu3 = []
    fc1 = []
    relu4 = []
    fc2 = []
    relu5 = []
    relu5_split = []
    fc1_view = []
    view_split = []
    inc_fc2 = []
    inc_relu2 = []
    inc_fc3 = []
    sum_view = []
    sum_trans = []
    relu1_view = []
    relu1_trans = []
    sum_view_split = []
    concat = []
    concat_split = []
    dec_fc1 = []
    dec_relu1 = []
    dec_fc2 = []
    dec_relu2 = []
    dec_relu2_split = []
    dec_img_fc1 = []
    dec_img_relu1 = []
    dec_img_fold = []
    dec_img_up1 = []
    dec_img_conv1 = []
    dec_img_relu2 = []
    dec_img_up2 = []
    dec_img_conv2 = []
    dec_img_relu3 = []
    dec_img_up3 = []
    dec_img_conv3 = []
    dec_mask_fc1 = []
    dec_mask_relu1 = []
    dec_mask_fold = []
    dec_mask_up1 = []
    dec_mask_conv1 = []
    dec_mask_relu2 = []
    dec_mask_up2 = []
    dec_mask_conv2 = []
    dec_mask_relu3 = []
    dec_mask_up3 = []
    dec_mask_conv3 = []

    # Encode ref, out, query.
    inputs = [ 'ref', 'out', 'query', 'target', 'diffid' ]
    for inp in inputs:
      conv1.append(self._make_conv_layer(network, kernel_size=5, stride=2, pad=2, num_output=64, shared_name='conv1'))
      conv1[-1].bottom.append(inp)
      relu1.append(self._make_relu_layer(network))

      conv2.append(self._make_conv_layer(network, kernel_size=5, stride=2, pad=2, num_output=72, shared_name='conv2'))
      relu2.append(self._make_relu_layer(network))

      conv3.append(self._make_conv_layer(network, kernel_size=5, stride=2, pad=2, num_output=96, shared_name='conv3'))
      relu3.append(self._make_relu_layer(network))

      fc1.append(self._make_inner_product_layer(network, num_output=1024, shared_name='fc1'))
      relu4.append(self._make_relu_layer(network))

      fc2.append(self._make_inner_product_layer(network, num_output=1024, shared_name='fc2'))
      relu5.append(self._make_relu_layer(network))
      relu5_split.append(self._make_split_layer(network))
      relu5_split[-1].name = '%s_relu5_split' % inp

      if inp=='query':
        fc1_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')
        id_split = self._make_split_layer(network)
        id_split.name = 'query_id_split'
      elif inp=='diffid':
        fc1_diff_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')
        fc1_diff_id.name = 'fc1_diff_id'
        continue # don't care about view from this input.
      elif inp=='target':
        target_fc1_id = self._make_inner_product_layer(network, num_output=512, shared_name='fc1_id')

      fc1_view.append(self._make_inner_product_layer(network, num_output=256, shared_name='fc1_view'))
      fc1_view[-1].name = '%s_fc1_view' % inp

    # Extract target hidden units.
    target_hid = self._make_concat_layer(network)
    target_hid.top.append('target_hid')

    # Extract trans = out_view - ref_view
    trans_view = self._make_sum_layer(network, coeff=[-1,1])
    trans_view_fc1 = self._make_inner_product_layer(network, num_output=256)
    trans_view_fc1_split = self._make_split_layer(network)
    trans_view_fc1_split.name = 'trans_view_fc1_split'

    # Make the encoding connections.
    connections = []
    for idx,inp in enumerate(inputs):
      connections.append((conv1[idx].name, (conv1[idx].top, relu1[idx].bottom, relu1[idx].top, conv2[idx].bottom)))
      connections.append((conv2[idx].name, (conv2[idx].top, relu2[idx].bottom, relu2[idx].top, conv3[idx].bottom)))
      connections.append((conv3[idx].name, (conv3[idx].top, relu3[idx].bottom, relu3[idx].top, fc1[idx].bottom)))
      connections.append((fc1[idx].name, (fc1[idx].top, relu4[idx].bottom, relu4[idx].top, fc2[idx].bottom)))
      connections.append((fc2[idx].name, (fc2[idx].top, relu5[idx].bottom)))
      connections.append((relu5[idx].name, (relu5[idx].top, relu5_split[idx].bottom)))

      if inp != 'diffid': # diffid has no view output.
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_view[idx].bottom)))
      if inp=='query':
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_id.bottom)))
        connections.append((fc1_id.name, (fc1_id.top, id_split.bottom)))
      elif inp=='diffid':
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, fc1_diff_id.bottom)))
      elif inp=='target':
        connections.append((relu5_split[idx].name, (relu5_split[idx].top, target_fc1_id.bottom)))
        connections.append((target_fc1_id.name, (target_fc1_id.top, target_hid.bottom)))
        connections.append((fc1_view[idx].name, (fc1_view[idx].top, target_hid.bottom)))


    # Connect view transformation.
    connections.append((fc1_view[0].name, (fc1_view[0].top, trans_view.bottom)))
    connections.append((fc1_view[1].name, (fc1_view[1].top, trans_view.bottom)))
    connections.append((trans_view.name, (trans_view.top, trans_view_fc1.bottom)))
    connections.append((trans_view_fc1.name, (trans_view_fc1.top, trans_view_fc1_split.bottom)))
    for t in range(numstep+1):
      if t < numstep:
        # Action.
        view_split.append(self._make_split_layer(network))
        sum_trans.append(self._make_sum_layer(network))
        relu1_trans.append(self._make_relu_layer(network))
        inc_fc2.append(self._make_inner_product_layer(network, num_output=256, shared_name='inc_fc2'))
        inc_relu2.append(self._make_relu_layer(network))
        inc_fc3.append(self._make_inner_product_layer(network, num_output=256, shared_name='inc_fc3'))
        sum_view.append(self._make_sum_layer(network))
        sum_view_split.append(self._make_split_layer(network))

      ###### Action connections ######
      # current view -> view_split. 
      # Go to both transformation increment and add to next step.
      if t==0: # Take view from query encoder.
        connections.append((fc1_view[2].name, (fc1_view[2].top, view_split[-1].bottom)))
      elif t<numstep:    # Take view from previous time step.
        connections.append((sum_view_split[-2].name, (sum_view_split[-2].top, view_split[-1].bottom)))

      if t<numstep:
        connections.append((view_split[-1].name, (view_split[-1].top, sum_trans[-1].bottom)))
        connections.append((trans_view_fc1_split.name, (trans_view_fc1_split.top, sum_trans[-1].bottom)))
        connections.append((sum_trans[-1].name, (sum_trans[-1].top, relu1_trans[-1].bottom, relu1_trans[-1].top, inc_fc2[-1].bottom)))
        connections.append((inc_fc2[-1].name, (inc_fc2[-1].top, inc_relu2[-1].bottom, inc_relu2[-1].top, inc_fc3[-1].bottom)))
        connections.append((inc_fc3[-1].name, (inc_fc3[-1].top, sum_view[-1].bottom)))
        connections.append((view_split[-1].name, (view_split[-1].top, sum_view[-1].bottom)))
        connections.append((sum_view[-1].name, (sum_view[-1].top, sum_view_split[-1].bottom)))

      # Decoder.
      concat.append(self._make_concat_layer(network))
      concat_split.append(self._make_split_layer(network))
      dec_fc1.append(self._make_inner_product_layer(network, num_output=1024, shared_name='dec_fc1'))
      dec_relu1.append(self._make_relu_layer(network))
      dec_fc2.append(self._make_inner_product_layer(network, num_output=1024, shared_name='dec_fc2'))
      dec_relu2.append(self._make_relu_layer(network))
      dec_relu2_split.append(self._make_split_layer(network))
      # Dec img path.
      dec_img_fc1.append(self._make_inner_product_layer(network, num_output=6144, shared_name='dec_img_fc1'))
      dec_img_relu1.append(self._make_relu_layer(network))
      dec_img_fold.append(self._make_folding_layer(network,96,8,8))
      dec_img_up1.append(self._make_upsampling_layer(network,stride=2))
      dec_img_conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=72, shared_name='dec_img_conv1'))
      dec_img_relu2.append(self._make_relu_layer(network))
      dec_img_up2.append(self._make_upsampling_layer(network,stride=2))
      dec_img_conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=64, shared_name='dec_img_conv2'))
      dec_img_relu3.append(self._make_relu_layer(network))
      dec_img_up3.append(self._make_upsampling_layer(network,stride=2))
      dec_img_conv3.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=3, shared_name='dec_img_conv3'))
      dec_img_conv3[-1].name = 'dec_img_conv3_t%d' % t
      # Dec mask path.
      dec_mask_fc1.append(self._make_inner_product_layer(network, num_output=4608, shared_name='dec_mask_fc1'))
      dec_mask_relu1.append(self._make_relu_layer(network))
      dec_mask_fold.append(self._make_folding_layer(network,72,8,8))
      dec_mask_up1.append(self._make_upsampling_layer(network,stride=2))
      dec_mask_conv1.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=64, shared_name='dec_mask_conv1'))
      dec_mask_relu2.append(self._make_relu_layer(network))
      dec_mask_up2.append(self._make_upsampling_layer(network,stride=2))
      dec_mask_conv2.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=32, shared_name='dec_mask_conv2'))
      dec_mask_relu3.append(self._make_relu_layer(network))
      dec_mask_up3.append(self._make_upsampling_layer(network,stride=2))
      dec_mask_conv3.append(self._make_conv_layer(network, kernel_size=5, stride=1, pad=2, num_output=1, shared_name='dec_mask_conv3'))
      dec_mask_conv3[-1].name = 'dec_mask_conv3_t%d' % t

      # dec connections.
      if t==numstep:
        # Diff ID decoder.
        connections.append((fc1_diff_id.name, (fc1_diff_id.top, concat[-1].bottom)))
        connections.append((view_split[0].name, (view_split[0].top, concat[-1].bottom)))
      else:
        # Normal decoder.
        connections.append((id_split.name, (id_split.top, concat[-1].bottom)))
        connections.append((sum_view_split[-1].name, (sum_view_split[-1].top, concat[-1].bottom)))
      connections.append((concat[-1].name, (concat[-1].top, concat_split[-1].bottom)))
      connections.append((concat_split[-1].name, (concat_split[-1].top, dec_fc1[-1].bottom)))
      connections.append((dec_fc1[-1].name, (dec_fc1[-1].top, dec_relu1[-1].bottom, dec_relu1[-1].top, dec_fc2[-1].bottom)))
      connections.append((dec_fc2[-1].name, (dec_fc2[-1].top, dec_relu2[-1].bottom)))
      connections.append((dec_relu2[-1].name, (dec_relu2[-1].top, dec_relu2_split[-1].bottom)))
      # dec image connections.
      connections.append((dec_relu2_split[-1].name, (dec_relu2_split[-1].top, dec_img_fc1[-1].bottom)))
      connections.append((dec_img_fc1[-1].name, (dec_img_fc1[-1].top, dec_img_relu1[-1].bottom, dec_img_relu1[-1].top, dec_img_fold[-1].bottom)))
      connections.append((dec_img_fold[-1].name, (dec_img_fold[-1].top, dec_img_up1[-1].bottom)))
      connections.append((dec_img_up1[-1].name, (dec_img_up1[-1].top, dec_img_conv1[-1].bottom)))
      connections.append((dec_img_conv1[-1].name, (dec_img_conv1[-1].top, dec_img_relu2[-1].bottom, dec_img_relu2[-1].top, dec_img_up2[-1].bottom)))
      connections.append((dec_img_up2[-1].name, (dec_img_up2[-1].top, dec_img_conv2[-1].bottom)))
      connections.append((dec_img_conv2[-1].name, (dec_img_conv2[-1].top, dec_img_relu3[-1].bottom, dec_img_relu3[-1].top, dec_img_up3[-1].bottom)))
      connections.append((dec_img_up3[-1].name, (dec_img_up3[-1].top, dec_img_conv3[-1].bottom)))
      # dec mask connections.
      connections.append((dec_relu2_split[-1].name, (dec_relu2_split[-1].top, dec_mask_fc1[-1].bottom)))
      connections.append((dec_mask_fc1[-1].name, (dec_mask_fc1[-1].top, dec_mask_relu1[-1].bottom, dec_mask_relu1[-1].top, dec_mask_fold[-1].bottom)))
      connections.append((dec_mask_fold[-1].name, (dec_mask_fold[-1].top, dec_mask_up1[-1].bottom)))
      connections.append((dec_mask_up1[-1].name, (dec_mask_up1[-1].top, dec_mask_conv1[-1].bottom)))
      connections.append((dec_mask_conv1[-1].name, (dec_mask_conv1[-1].top, dec_mask_relu2[-1].bottom, dec_mask_relu2[-1].top, dec_mask_up2[-1].bottom)))
      connections.append((dec_mask_up2[-1].name, (dec_mask_up2[-1].top, dec_mask_conv2[-1].bottom)))
      connections.append((dec_mask_conv2[-1].name, (dec_mask_conv2[-1].top, dec_mask_relu3[-1].bottom, dec_mask_relu3[-1].top, dec_mask_up3[-1].bottom)))
      connections.append((dec_mask_up3[-1].name, (dec_mask_up3[-1].top, dec_mask_conv3[-1].bottom)))

    layers = [ fc1_id, fc1_diff_id, id_split, trans_view, trans_view_fc1, trans_view_fc1_split, target_hid, target_fc1_id ]
    layers += conv1
    layers += conv2
    layers += conv3
    layers += relu1
    layers += relu2
    layers += relu3
    layers += fc1
    layers += relu4
    layers += fc2
    layers += relu5
    layers += relu5_split
    layers += fc1_view
    layers += view_split
    layers += relu1_view
    layers += inc_fc2
    layers += inc_relu2
    layers += inc_fc3
    layers += sum_view
    layers += sum_trans
    layers += relu1_trans
    layers += concat
    layers += concat_split
    layers += sum_view_split
    layers += dec_fc1
    layers += dec_relu1
    layers += dec_fc2
    layers += dec_relu2
    layers += dec_relu2_split
    layers += dec_img_fc1
    layers += dec_img_relu1
    layers += dec_img_fold
    layers += dec_img_up1
    layers += dec_img_conv1
    layers += dec_img_relu2
    layers += dec_img_up2
    layers += dec_img_conv2
    layers += dec_img_relu3
    layers += dec_img_up3
    layers += dec_img_conv3
    layers += dec_mask_fc1
    layers += dec_mask_relu1
    layers += dec_mask_fold
    layers += dec_mask_up1
    layers += dec_mask_conv1
    layers += dec_mask_relu2
    layers += dec_mask_up2
    layers += dec_mask_conv2
    layers += dec_mask_relu3
    layers += dec_mask_up3
    layers += dec_mask_conv3

    # Final concat.
    #if numstep > 1:
    dec_img_cat = self._make_concat_layer(network, concat_dim=0)
    dec_img_cat.top.append('dec_img_cat')
    dec_mask_cat = self._make_concat_layer(network, concat_dim=0)
    dec_mask_cat.top.append('dec_mask_cat')
    for l in dec_img_conv3:
      connections.append((l.name, (l.top, dec_img_cat.bottom)))
    for l in dec_mask_conv3:
      connections.append((l.name, (l.top, dec_mask_cat.bottom)))
    layers.append(dec_img_cat)
    layers.append(dec_mask_cat)
    #else:
    #  dec_img_conv3[-1].top.append('dec_img_cat')
    #  dec_mask_conv3[-1].top.append('dec_mask_cat')

    # Final hidden output.
    if numstep > 1:
      hid_out = self._make_concat_layer(network, concat_dim=0)
      hid_out.top.append('hid_out')
      layers.append(hid_out)
      for idx,l in enumerate(concat_split[:-1]):
        l.name = 'hid_out_t%d' % idx
        connections.append((l.name, (l.top, hid_out.bottom)))
    else:
      #concat_split[-1].top.append('hid_out')
      concat_split[0].top.append('hid_out')

    # make connections.
    for connection in connections:
      self._tie(connection, name_generator)

    # Fix up the names based on the connections that were generated.
    for pos, layer in enumerate(layers):
      layer.name += '_%d' % pos

    return network


  def build_network(self, netname, batchsize=1, numstep=4):
    """Build the named network, write it to a prototxt file and load it.

    Args:
      netname: 'analogy_rnn' or 'analogy_rnn_2layer'; selects which builder
        method is called.
      batchsize: forwarded unchanged to the selected builder method.
      numstep: forwarded unchanged to the selected builder method; also
        embedded in the output filename.

    Returns:
      The result of `Net(network_filename)` for the written
      '<netname>_t<numstep>.prototxt' file, or None when `netname` is not
      one of the recognised names (a message is printed in that case and no
      file is written).
    """

    if netname == 'analogy_rnn':
      network = self._build_rnn_network(batchsize=batchsize, numstep=numstep)
    elif netname == 'analogy_rnn_2layer':
      network = self._build_rnn_network_2layer(batchsize=batchsize,
                                               numstep=numstep)
    else:
      print('unknown netname: %s' % netname)
      return None

    network_filename = '%s_t%d.prototxt' % (netname, numstep)
    # Parenthesised single-argument form prints identically under the
    # Python 2 print statement and the Python 3 print function, so the
    # module no longer fails to parse on Python 3.
    print(network)
    # Serialise the network message as text, then hand the file to Net.
    with open(network_filename, 'w') as network_file:
      network_file.write(text_format.MessageToString(network))
    return Net(network_filename)


if __name__ == '__main__':
  # Script entry point: build the 'analogy_rnn_2layer' network with
  # batchsize 1 and the step count given as the first command-line argument.
  if len(sys.argv) < 2:
    # Without this check a missing argument surfaces as a bare IndexError.
    sys.exit('usage: %s NUMSTEP' % sys.argv[0])
  __Network_builder__ = NetworkBuilder()
  # Alternative: build the 'analogy_rnn' variant instead.
  #__Network_builder__.build_network(netname='analogy_rnn', batchsize=1,
  #                                  numstep=int(sys.argv[1]))
  __Network_builder__.build_network(netname='analogy_rnn_2layer', batchsize=1,
                                    numstep=int(sys.argv[1]))

