endaaman.com

2019-07-22

Tips

PyTorchのnn.CrossEntropyLossで急にlossが極大値に振れてしまう

多クラスのsegmentaionで起きやすいかもしれない

lossが突然跳ね上がる現象

セマンティックセグメンテーションのネットワークを訓練していると、学習がある程度進行すると減少していたlossが突然跳ね上がるように大きくなる現象が起きてしまった。

状況など

基本的なセマンティックセグメンテーションの訓練にnn.CrossEntropyLossを使う。これを使うときはモデルの出力には活性関数を通さずに(softmaxやsigmoidなど)せずに、logitのまま出力する。

そのときの学習のコードは

criterion = nn.CrossEntropyLoss()

for inputs, labels in enumerate(data_loader):
    outputs = model(inputs).to(device)
    loss = criterion(outputs, labels)

のようになる。このまま学習すると、softmaxに入れるモデル出力のlogitが大きくなりすぎて、softmaxに通したとき0になる要素が現れてしまい、logに通したときにInfに振れてしまう現象が起こるようである。

修正案

前提として、nn.CrossEntropyLoss()softmax → cross-entrooy (NLL → log) を一挙に行う関数として提供されており、モデルのケツに活性化関数をつけず(0〜1に押しこめず)にlogitのまま入力されることを期待しているが、これだと扱いにくい。なのでnn.CrossEntropyLoss()は使わずに自分で実装することにする。

実装としては、まずモデルのケツで活性化関数をくぐらせるようにし、出力を常に0〜1で得られるようにする。

    def forward(self, x):
        ### ...snip
        if self.num_classes > 1:
            return torch.softmax(out, dim=1)
        else:
            return torch.sigmoid(out)

多クラスならsoftmax、1クラスならsimogidを使うのが普通だ。

そしてロス関数はNLL → 微小値付与 → log と行うようにする。実装としては、

class CrossEntropyLoss2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.NLLLoss()

    # `x`と`y`はそれぞれモデル出力の値とGTの値
    def forward(self, x, y, eps=1e-24):
        # (BATCH, C, H, W) から (C, BATCH * H * W) に並べ直す
        x = x.contiguous().view(-1, NUM_CLASSES)
        # 微小値を足してからlogる※1
        x = x.log(x + eps)

        y = y.contiguous().view(-1, NUM_CLASSES)
        _, y = torch.max(y, -1)
        return self.loss_fn(x, y)

※1微小値を全要素に足しているが、max(p, eps)でも良い。ただ単純に足すだけのほうがおそらく速い早い(実験してないので分からない)。ともかく0ぴったりにならないようにする。

こんな感じ。NLLLossはsoftmaxのlogをしないバージョンで、 (Batch, Channel) 形式の入力を期待するので、transpose→flattenして変形する。

これで学習が安定するようになった。

ひとこと


©2024 endaaman.com