Я применил алгоритм большинства голосов (подсчет прогноза разных классификаторов) на Tensorflow 1.10, и он очень медленный (занимает более 3 часов для 10 классификаторов) для прогнозирования набора данных с размером 1000 (MNIST). Основываясь на моей догадке, это связано с большим вызовом session.run() в моем коде, но как я могу его оптимизировать?
def majority_voting(session, x, y):
votes = []
for i in range(number_of_ensemble_modules):
# run the training
feature_extractor = iterators[i][3]
input, label = feature_extractor(x, y)
transformed_x = session.run(input)
ensemble_prediction = nn_models[0][i][0][3]
prediction = session.run(ensemble_prediction, feed_dict={X: transformed_x, Y: y})
votes.append(prediction[0])
nearest_k_y, idx, vote = tf.unique_with_counts(tf.convert_to_tensor(votes, tf.int64))
majority = tf.argmax(vote)
predict_res = tf.gather(nearest_k_y, majority)
return predict_res
def calculate_ensemble_accuracy():
accuracy = 0
for j in range(voting_iterations):
accuracy += 0
(features, labels) = session.run(next_element)
vote = majority_voting(session, features, labels)
correct_label = session.run(tf.argmax(labels, axis=1))
if vote == correct_label[0]:
accuracy += 1
return accuracy
Некоторые советы, которые могут решить вашу проблему.
1- Удалите функцию перед созданием tensorflow graph
. Например, если вы создаете TfIDF
функции TfIDF
, вы можете сделать это на этапе препроцесса и сохранить numpy для ввода графика.
input, label = feature_extractor(x, y)
2- Удалите ненужный session.run()
. Например, когда вы вызываете Оптимизатор, он автоматически вызывает x_transformed.
transformed_x = session.run(input)
3- Используйте tf.data
(API-интерфейс Dataset API) лучше. sess.run(next_element)
необходимости вызывать sess.run(next_element)
. Потому что next_element
является частью вашего графика.