Как увеличить тензор (повторяющееся значение) в тензорном потоке?

Я новичок в TensorFlow. Я пытаюсь реализовать извлечение global_context в этой статье https://arxiv.org/abs/1506.04579, который на самом деле представляет собой средний пул по всей карте объектов, а затем дублирует карту объектов 1x1 до исходного размера. Иллюстрация ниже

введите описание изображения здесь

В частности, ожидаемая операция следующая. ввод: [N, 1, 1, C] тензор, где N — размер пакета, а C — номер канала вывод: [N, H, W, C] тензор, где H, W — высота и ширина оригинала карта функций, и все выходные значения H * W такие же, как и входные данные 1x1.

Например,

    [[1, 1, 1]
1 -> [1, 1, 1]
     [1, 1, 1]]

Я понятия не имею, как это сделать с помощью TensorFlow. Для tf.image.resize_images требуется 3 канала, а tf.pad не может дополнять постоянное значение, отличное от нуля.


person jackykuo    schedule 10.03.2017    source источник


Ответы (1)


tf.tile может вам помочь

x = tf.constant([[1, 2, 3]]) # shape (1, 3)
y = tf.tile(x, [3, 1]) # shape (3, 3)
y_ = tf.tile(x, [3, 2]) # shape (3, 6)

with tf.Session() as sess:
    a, b, c = sess.run([x, y, y_])

>>>a
array([[1, 2, 3]], dtype=int32)
>>>b
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>>c
array([[1, 2, 3, 1, 2, 3],
       [1, 2, 3, 1, 2, 3],
       [1, 2, 3, 1, 2, 3]], dtype=int32)

tf.tile(input, multiples, name=None)
multiples означает, сколько раз вы хотите повторить по этой оси
в y повторить ось0 3 раза
в y_ повторить ось0 3 раза, а ось1 2 раза

вам может понадобиться сначала использовать tf.expand_dim

Да, он принимает динамическую форму

x = tf.placeholder(dtype=tf.float32, shape=[None, 4])
x_shape = tf.shape(x)
y = tf.tile(x, [3 * x_shape[0], 1])

with tf.Session() as sess:
    x_ = np.array([[1, 2, 3, 4]])
    a = sess.run(y, feed_dict={x:x_})
>>>a
array([[ 1.,  2.,  3.,  4.],
       [ 1.,  2.,  3.,  4.],
       [ 1.,  2.,  3.,  4.]], dtype=float32)
person xxi    schedule 10.03.2017
comment
Спасибо, я должен был найти это раньше. Можно ли использовать tf.tile с динамической формой тензора? например, tf.tile(input, [1, ori.get_shape()[1], ori.get_shape()[2], 1]). Я не хочу фиксировать скорость увеличения в сети. - person jackykuo; 10.03.2017
comment
Вторая строка в вашем первом блоке кода должна иметь комментарий # shape (3,3), если я не ошибаюсь - person jvrsgsty; 17.01.2019