lossが突然跳ね上がる現象
セマンティックセグメンテーションのネットワークを訓練していると、学習がある程度進行すると減少していたlossが突然跳ね上がるように大きくなる現象が起きてしまった。
状況など
基本的なセマンティックセグメンテーションの訓練にnn.CrossEntropyLoss
を使う。これを使うときはモデルの出力には活性関数を通さずに(softmaxやsigmoidなど)せずに、logitのまま出力する。
そのときの学習のコードは
criterion = nn.CrossEntropyLoss()
for inputs, labels in enumerate(data_loader):
outputs = model(inputs).to(device)
loss = criterion(outputs, labels)
のようになる。このまま学習すると、softmaxに入れるモデル出力のlogitが大きくなりすぎて、softmaxに通したとき0になる要素が現れてしまい、logに通したときにInfに振れてしまう現象が起こるようである。
修正案
前提として、nn.CrossEntropyLoss()
は softmax → cross-entrooy (NLL → log)
を一挙に行う関数として提供されており、モデルのケツに活性化関数をつけず(0〜1に押しこめず)にlogitのまま入力されることを期待しているが、これだと扱いにくい。なのでnn.CrossEntropyLoss()
は使わずに自分で実装することにする。
実装としては、まずモデルのケツで活性化関数をくぐらせるようにし、出力を常に0〜1で得られるようにする。
def forward(self, x):
### ...snip
if self.num_classes > 1:
return torch.softmax(out, dim=1)
else:
return torch.sigmoid(out)
多クラスならsoftmax、1クラスならsimogidを使うのが普通だ。
そしてロス関数はNLL → 微小値付与 → log
と行うようにする。実装としては、
class CrossEntropyLoss2d(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.NLLLoss()
# `x`と`y`はそれぞれモデル出力の値とGTの値
def forward(self, x, y, eps=1e-24):
# (BATCH, C, H, W) から (C, BATCH * H * W) に並べ直す
x = x.contiguous().view(-1, NUM_CLASSES)
# 微小値を足してからlogる※1
x = x.log(x + eps)
y = y.contiguous().view(-1, NUM_CLASSES)
_, y = torch.max(y, -1)
return self.loss_fn(x, y)
※1微小値を全要素に足しているが、max(p, eps)
でも良い。ただ単純に足すだけのほうがおそらく速い早い(実験してないので分からない)。ともかく0ぴったりにならないようにする。
こんな感じ。NLLLoss
はsoftmaxのlogをしないバージョンで、 (Batch, Channel)
形式の入力を期待するので、transpose→flattenして変形する。
これで学習が安定するようになった。
ひとこと
みんな大好きゼロから作るDeep Learningにも書かれている話だけど、実際に起こるとは思わなかった。PyTorchのデフォルトのF.log_softmax()
も間にこの処理が挟まってないし、それに対応するオプションもないので混乱した。