PyTorch: потери остаются постоянными

1

Я написал код в PyTorch с моей собственной функцией потерь focal_loss_fixed. Но мое значение потери остается фиксированным после каждой эпохи. Похоже, что веса не обновляются. Вот мой фрагмент кода:

optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)


for epoch in T(range(20)):
    net.train()
    epoch_loss = 0
    for n in range(len(x_train)//batch_size):
        (imgs, true_masks) = data_gen_small(x_train, y_train, iter_num=n, batch_size=batch_size)
        temp = []
        for tt in true_masks:
            temp.append(tt.reshape(128, 128, 1))
        true_masks = np.copy(np.array(temp))
        del temp
        imgs = np.swapaxes(imgs, 1,3)
        imgs = torch.from_numpy(imgs).float().cuda()
        true_masks = torch.from_numpy(true_masks).float().cuda()
        masks_pred = net(imgs)
        masks_probs = F.sigmoid(masks_pred)
        masks_probs_flat = masks_probs.view(-1)
        true_masks_flat = true_masks.view(-1)
        print((focal_loss_fixed(tf.convert_to_tensor(true_masks_flat.data.cpu().numpy()), tf.convert_to_tensor(masks_probs_flat.data.cpu().numpy()))))
        loss = torch.from_numpy(np.array(focal_loss_fixed(tf.convert_to_tensor(true_masks_flat.data.cpu().numpy()), tf.convert_to_tensor(masks_probs_flat.data.cpu().numpy())))).float().cuda()
        loss = Variable(loss.data, requires_grad=True)
        epoch_loss *= (n/(n+1))
        epoch_loss += loss.item()*(1/(n+1))
        print('Step: {0:.2f}% --- loss: {1:.6f}'.format(n * batch_size* 100.0 / len(x_train), epoch_loss), end='\r')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch finished ! Loss: {}'.format(epoch_loss))

И это моя функция focal_loss_fixed:

def focal_loss_fixed(true_data, pred_data):
    gamma=2.
    alpha=.25
    eps = 1e-7
    # print(type(y_true), type(y_pred))
    pred_data = K.clip(pred_data,eps,1-eps)
    pt_1 = tf.where(tf.equal(true_data, 1), pred_data, tf.ones_like(pred_data))
    pt_0 = tf.where(tf.equal(true_data, 0), pred_data, tf.zeros_like(pred_data))
    with tf.Session() as sess:
        return sess.run(-K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0)))

После каждой эпохи значение потерь остается постоянным (5589.60328). Что с этим не так?

Теги:
machine-learning
pytorch
loss-function

2 ответа

1

При вычислении потерь вы вызываете focal_loss_fixed() который использует TensorFlow для вычисления значения потерь. focal_loss_fixed() создает график и запускает его в сеансе, чтобы получить значение, и к этому моменту PyTorch не имеет понятия о последовательности операций, которые привели к потере, потому что они были вычислены с помощью TensorFlow. Вероятно, тогда все, что PyTorch видит в loss является константой, как если бы вы написали

loss = 3

Таким образом, градиент будет равен нулю, и параметры никогда не будут обновляться. Я предлагаю вам переписать функцию потерь с помощью операций PyTorch, чтобы можно было вычислить градиент по отношению к его входам.

0

Я думаю, что проблема кроется в вашем тяжелом распуте.

По сути, вы не уменьшаете вес на x, а скорее умножаете веса на x, а это означает, что вы мгновенно делаете очень маленькие приращения, что приводит к (по-видимому) функции потери пламени.

Более подробное объяснение этого можно найти в дискуссионном форуме PyTorch (например, здесь или здесь.
К сожалению, источник только для SGD также не говорит вам о его реализации. Простое установление большего значения должно привести к улучшению обновлений. Вы можете начать, оставив его полностью, а затем итеративно уменьшая его (от 1) до получения достойных результатов.

  • 0
    Dennlinger, как я могу изменить weight decay в этом коде? Я новичок в PyTorch и понятия не имею.
  • 0
    @ tahsin314 Снижение веса находится в переменной оптимизатора (SGD). Сделайте его еще меньше или постарайтесь его не использовать.
Показать ещё 3 комментария

Ещё вопросы

Сообщество Overcoder
Наверх
Меню