LSTM в Keras: количество параметров отличается между последовательным и функциональным API

1

С помощью последовательного API

Если я создаю LSTM с Sequential API Keras со следующим кодом:

from keras.models import Sequential
from keras.layers import LSTM

model = Sequential()
model.add(LSTM(2, input_dim=3))

затем

model.summary()

возвращает 48 параметров, что в порядке, как указано в этом вопросе.

Изображение 174551

Краткие сведения:

input_dim = 3, output_dim = 2
n_params = 4 * output_dim * (output_dim + input_dim + 1) = 4 * 2 * (2 + 3 + 1) = 48

Функциональный API

Но если я сделаю то же самое с функциональным API со следующим кодом:

from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM

inputs = Input(shape=(3, 1))
lstm = LSTM(2)(inputs)
model = Model(input=inputs, output=lstm)

затем

model.summary()

возвращает 32 параметра.

Изображение 174551

Почему существует такая разница?

Теги:
machine-learning
keras
lstm
recurrent-neural-network

2 ответа

3
Лучший ответ

Разница в том, что когда вы передаете input_dim=x на уровень RNN, включая слои LSTM, это означает, что форма ввода (None, x) т.е. существует различное количество временных меток, где каждый из них является вектором длины x. Однако в примере функционального API вы определяете shape=(3, 1) как форму ввода, а это означает, что есть три временных элемента, каждый из которых имеет одну функцию. Поэтому количество параметров будет: 4 * output_dim * (output_dim + input_dim + 1) = 4 * 2 * (2 + 1 + 1) = 32 что является числом, указанным в сводке модели.

Кроме того, если вы используете Keras 2.xx, вы получите предупреждение в случае использования аргумента input_dim для слоя RNN:

UserWarning: аргументы input_dim и input_length в повторяющихся уровнях устарели. input_shape этого используйте input_shape.

UserWarning: обновите свой LSTM вызов API LSTM(2, input_shape=(None, 3)) 2: LSTM(2, input_shape=(None, 3))

  • 1
    Если в Функциональном API я заменяю input = Input (shape = (3, 1)) на input = Input (shape = (1, 3)) , я получаю 48 параметров, как и ожидалось. Спасибо!
0

Я решил это следующим образом:

Case 1:
m (input) = 3
n (output) = 2

params = 4 * ( (input * output) + (output ^ 2) + output)
       = 4 * (3*2 + 2^2 + 2)
       = 4 * (6 + 4 + 2)
       = 4 * 12
       = 48



Case 2:
m (input) = 1
n (output) = 2

params = 4 * ( (input * output) + (output ^ 2) + output)
       = 4 * (1*2 + 2^2 + 2)
       = 4 * (2 + 4 + 2)
       = 4 * 8
       = 32

Ещё вопросы

Сообщество Overcoder
Наверх
Меню