Медленное выполнение Tensorflow при голосовании большинства

1

Я применил алгоритм большинства голосов (подсчет прогноза разных классификаторов) на Tensorflow 1.10, и он очень медленный (занимает более 3 часов для 10 классификаторов) для прогнозирования набора данных с размером 1000 (MNIST). Основываясь на моей догадке, это связано с большим вызовом session.run() в моем коде, но как я могу его оптимизировать?

def majority_voting(session, x, y):
    votes = []
    for i in range(number_of_ensemble_modules):
        # run the training
        feature_extractor = iterators[i][3]
        input, label = feature_extractor(x, y)
        transformed_x = session.run(input)
        ensemble_prediction = nn_models[0][i][0][3]
        prediction = session.run(ensemble_prediction, feed_dict={X: transformed_x, Y: y})
        votes.append(prediction[0])
    nearest_k_y, idx, vote = tf.unique_with_counts(tf.convert_to_tensor(votes, tf.int64))
    majority = tf.argmax(vote)
    predict_res = tf.gather(nearest_k_y, majority)
    return predict_res


def calculate_ensemble_accuracy():
    accuracy = 0
    for j in range(voting_iterations):
        accuracy += 0
        (features, labels) = session.run(next_element)
        vote = majority_voting(session, features, labels)
        correct_label = session.run(tf.argmax(labels, axis=1))
        if vote == correct_label[0]:
            accuracy += 1
    return accuracy
  • 0
    Под «10 предикторами» я понимаю, что вы имеете в виду 10 классификаторов, верно?
  • 0
    Да, я имею в виду классификатор (или предиктор)
Показать ещё 2 комментария
Теги:
tensorflow
machine-learning

1 ответ

0

Некоторые советы, которые могут решить вашу проблему.

1- Удалите функцию перед созданием tensorflow graph. Например, если вы создаете TfIDF функции TfIDF, вы можете сделать это на этапе препроцесса и сохранить numpy для ввода графика.

input, label = feature_extractor(x, y)

2- Удалите ненужный session.run(). Например, когда вы вызываете Оптимизатор, он автоматически вызывает x_transformed.

transformed_x = session.run(input)

3- Используйте tf.data (API-интерфейс Dataset API) лучше. sess.run(next_element) необходимости вызывать sess.run(next_element). Потому что next_element является частью вашего графика.

  • 0
    Спасибо за вашу помощь и предложения, трудоемкая часть находится в части голосования, которая должна собрать все прогнозы (вызывая session.run ()), и я должен оптимизировать эту часть.

Ещё вопросы

Сообщество Overcoder
Наверх
Меню