Как попарно объединить строку и столбец объектов в матрицу объектов в кератах?

1

В keras, используя функциональный API, у меня есть два независимых уровня (тензоры). Первый - это вектор строки списков функций, а другой - вектор столбца списков функций. Для простоты предположим, что они созданы следующим образом:

rows = 5
cols = 10
features = 2

row = Input((1, cols, features))
col = Input((rows, 1, features))

Теперь я хочу, чтобы "слить" эти два слоя таким образом, что результатом является матрица с 5 строк и 10 столбцов ( в основном делают на 5x1 по 1x10 умножения матриц), где каждый элемент этой матрицы является каскадный список функций каждой возможной комбинации вектора строки и столбца. Другими слова, я ищу какое - то MergeLayer, которая будет сочетать мои row и col слои к matrix слою формы (rows, cols, 2*features):

matrix = MergeLayer()([row, col]) # output_shape of matrix shall be (rows, cols, 2*features)

Пример для cols = rows = 2:

row = [[[1,2]], [[3,4]]]
col = [[[5,6],
        [7,8]]]

matrix = [[[1,2,5,6], [3,4,5,6]],
          [[1,2,7,8], [3,4,7,8]]]

Я предполагаю, что решение (если возможно вообще) будет каким-то образом использовать слой Dot и, возможно, некоторые Reshape и/или Permute, но я не могу понять это.

Теги:
keras

1 ответ

1
Лучший ответ

Вы можете повторять элементы, а затем конкатенировать.

from keras.layers import Input, Lambda, Concatenate
from keras.models import Model
import keras.backend as K

rows = 2
cols = 2
features = 2

row = Input((1, cols, features))
col = Input((rows, 1, features))

row_repeated = Lambda(lambda x: K.repeat_elements(x, rows, axis=1))(row)
col_repeated = Lambda(lambda x: K.repeat_elements(x, cols, axis=2))(col)
out = Concatenate()([row_repeated, col_repeated])

model = Model(inputs=[row,col], outputs=out)
model.summary()

Эксперимент:

import numpy as np

x = np.array([1,2,3,4]).reshape((1, 1, 2, 2))
y = np.array([5,6,7,8]).reshape((1, 2, 1, 2))
model.predict([x, y])

#array([[[[1., 2., 5., 6.],
#         [3., 4., 5., 6.]],
#
#        [[1., 2., 7., 8.],
#         [3., 4., 7., 8.]]]], dtype=float32)

Ещё вопросы

Сообщество Overcoder
Наверх
Меню