Как использовать индексы из tf.nn.top_k с tf.gather_nd?

2

Я пытаюсь использовать индексы, возвращенные из tf.nn.top_k, чтобы извлечь значения из второго тензора.

Я пытался использовать индексирование типа numpy, а также напрямую tf.gather_nd, но заметил, что индексация неверна.

#  temp_attention_weights of shape [I, B, 1, J]
top_values, top_indices = tf.nn.top_k(temp_attention_weights, k=top_k)

# top_indices of shape [I, B, 1, top_k], base_encoder_transformed of shape [I, B, 1, J]

# I now want to extract from base_encoder_transformed top_indices
base_encoder_transformed = tf.gather_nd(base_encoder_transformed, indices=top_indices)  

# base_encoder_transformed should be of shape [I, B, 1, top_k]

Я заметил, что top_indices имеет неправильный формат, но я не могу преобразовать его для использования в tf.gather_nd, где самое внутреннее измерение используется для индексации каждого соответствующего элемента из base_encoder_transformed. Кто-нибудь знает способ получить top_indices в правильном формате?

Теги:
tensorflow

2 ответа

3
Лучший ответ

top_indices будет индексировать только по последней оси, вам нужно добавить индексы и для остальных осей. Это легко с tf.meshgrid:

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
# Make indices for the rest of axes
ii, jj, kk, _ = tf.meshgrid(
    tf.range(I),
    tf.range(B),
    tf.range(1),
    tf.range(top_k),
    indexing='ij')
# Stack complete index
index = tf.stack([ii, jj, kk, top_indices], axis=-1)
# Get the same values again
top_values_2 = tf.gather_nd(x, index)
# Test
with tf.Session() as sess:
    v1, v2 = sess.run([top_values, top_values_2])
    print((v1 == v2).all())
    # True
0

Я попробовал другой способ решить эту проблему, используя tf.unstack и tf.stack, для некоторых случаев, когда я не могу получить информацию о форме.

# just for simple
base_encoder_transformed = tf.squeeze(base_encoder_transformed,axis=2) # shape [I,B,J]
top_indices = tf.squeeze(top_indices,axis=2) # shape [I,B,top_k]

# base_encoder_transformed should be of shape [I, B, 1, top_k]
baseI = tf.unstack(base_encoder_transformed) # ( [(B,J)]*I )
indI = tf.unstack(top_indices) # ( [B,top_k] * I)
output = []
for i in range(len(baseI)):
    baseiB = tf.unstack(baseI[i]) # ( [(J,)]*B )
    indiB = tf.unstack(indI[i]) # ( [(top_k,)]*B)
    outputB = []
    for b in range(len(baseiB)):
        outputB.append(tf.gather_nd(baseiB[b],tf.expand_dims(indiB[b],1))) #([top_k,])
    output.append(tf.stack(outputB,axis=0)) # ( [B, top_k])


base_encoder_transformed = tf.stack(output) #([I,B,top_k])
base_encoder_transformed = tf.expand_dims(base_encoder_transformed,axis=2) #([I,B,1,top_k])

Ещё вопросы

Сообщество Overcoder
Наверх
Меню