У меня есть два тензора, и я должен повторить первый, чтобы взять только тот элемент, который находится внутри другого тензора. В t2
есть только один элемент, который также находится внутри t1
. Вот пример
t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]
t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]
t3 = .... # [3, 0]
Я попытался оценить и перебирать их, используя .eval()
и проверяется, если элемент t2
в t1
с помощью оператора in
, но не работает. Есть ли функция от TensorFlow, которая может это сделать?
редактировать
for index in xrange(max_indices):
indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]
cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]
indices_list.append(indices)
for cent in cent_indices:
if cent in indices:
centers_list.append(cent)
break
Первый итерационный cent
имеет значение [6 0]
но он входит в условие if
.
ответ
for index in xrange(max_indices):
indices = tf.where(tf.equal(values, (index + 1))).eval()
cent_indices = tf.where(centers > 0).eval()
indices_list.append(indices)
for cent in cent_indices:
# batch_item is an iterator from an outer loop
if values[batch_item, cent[0]].eval() == (index + 1):
centers_list.append(tf.constant(cent))
break
Решение связано с моей задачей, но если вы ищете решение в тензоре 1D, я предлагаю посмотреть на tf.sets.set_intersection
Это то, что вы хотели? Я использовал только эти два теста.
x = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 1]])
y = tf.constant([[1, 2, 3, 4, 3, 6], [1, 2, 3, 4, 5, 1]])
# x = tf.constant([[1, 2], [4, 5], [7, 7]])
# y = tf.constant([[7, 7], [3, 5]])
def match(xiterations, yiterations, yvalues, xvalues ):
for i in range(xiterations):
for j in range(yiterations):
if (np.array_equal(yvalues[j], xvalues[i])):
print( yvalues[j])
with tf.Session() as sess:
xindex = tf.where( x > 4 )
yindex = tf.where( y > 4 )
xvalues = xindex.eval()
yvalues = yindex.eval()
xiterations = tf.shape(xvalues)[0].eval()
yiterations = tf.shape(yvalues)[0].eval()
print(tf.shape(xvalues)[0].eval())
print(tf.shape(yvalues)[0].eval())
if tf.shape(xvalues)[0].eval() >= tf.shape(yvalues)[0].eval():
match( xiterations, yiterations, yvalues, xvalues)
else:
match( yiterations, xiterations, xvalues, yvalues)