Передача> 2 ГБ данных в tf.estimator

1

У меня есть x_train и y_train, каждый из> 2GB. Я хочу обучить модель с использованием API tf.estimator, но получаю ошибки:

ValueError: Cannot create a tensor proto whose content is larger than 2GB

Я передаю данные, используя:

def input_fn(features, labels=None, batch_size=None,
             shuffle=False, repeats=False):
    if labels is not None:
        inputs = (features, labels)
    else:
        inputs = features
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    if shuffle:
        dataset = dataset.shuffle(shuffle)
    if batch_size:
        dataset = dataset.batch(batch_size)
    if repeats:
        # if False, evaluate after each epoch
        dataset = dataset.repeat(repeats)
    return dataset

train_spec = tf.estimator.TrainSpec(
    lambda : input_fn(x_train, y_train,
                      batch_size=BATCH_SIZE, shuffle=50),
    max_steps=EPOCHS
)

eval_spec = tf.estimator.EvalSpec(lambda : input_fn(x_dev, y_dev))

tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

В документации tf.data упоминается об этой ошибке и предоставляется решение с использованием традиционного API TenforFlow с заполнителями. К сожалению, я не знаю, как это можно перевести в API tf.estimator?

Теги:
tensorflow
tensorflow-datasets
tensorflow-estimator

1 ответ

0

Решение, которое работало для меня, использовало

tf.estimator.inputs.numpy_input_fn(x_train, y_train, num_epochs=EPOCHS,
                                   batch_size=BATCH_SIZE, shuffle=True)

вместо input_fn. Единственная проблема заключается в том, что tf.estimator.inputs.numpy_input_fn вызывает предупреждения об устаревании, поэтому, к сожалению, это также перестанет работать.

Ещё вопросы

Сообщество Overcoder
Наверх
Меню