Как проверить пользовательскую функцию потерь в керасе?

3

Я тренирую CovNet с двумя выходами. Мои учебные образцы выглядят так:

[0, value_a1], [0, value_a2], ...

а также

[value_b1, 0], [value_b2, 0], ....

Я хочу создать свою собственную функцию потери и маскировать пары, содержащие mask_value = 0. У меня есть эта функция, хотя я не уверен, действительно ли она делает то, что я хочу. Итак, я хочу написать несколько тестов.

from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses

def masked_loss_function(y_true, y_pred, mask_value=0):
    '''
    This model has two target values which are independent of each other.
    We mask the output so that only the value that is used for training 
    contributes to the loss.
        mask_value : is the value that is not used for training
    '''
    mask = K.cast(K.not_equal(y_true, mask_value), K.floatx())
    return losses.mean_squared_error(y_true * mask, y_pred * mask)

Хотя, я не знаю, как я могу проверить эту функцию с керасом? Обычно это передается model.compile(). Что-то вроде этих строк:

x = [1, 0]
y = [1, 1]
assert masked_loss_function(x, y, 0) == 0
Теги:
tensorflow
testing
keras
loss-function

2 ответа

3
Лучший ответ

Я думаю, что одним из способов достижения этой цели является использование функции Keras backend. Здесь мы определяем функцию, которая принимает в качестве входных данных два тензора и возвращает в качестве выходного тензора:

from keras import Model
from keras import layers

x = layers.Input(shape=(None,))
y = layers.Input(shape=(None,))
loss_func = K.Function([x, y], [masked_loss_function(x, y, 0)])

И теперь мы можем использовать loss_func для запуска loss_func графика, который мы определили:

assert loss_func([[[1,0]], [[1,1]]]) == [[0]]

Обратите внимание, что функция keras backend, Function, ожидает, что входные и выходные аргументы будут массивом тензоров. Кроме того, x и y принимает партию тензоров, т.е. Массив тензоров с неопределенной формой.

0

Это еще один обходной путь,

x = [1, 0]
y = [1, 1]
F = masked_loss_function(K.variable(x), K.variable(y), K.variable(0))
assert K.eval(F) == 0 

Ещё вопросы

Сообщество Overcoder
Наверх
Меню