ignore_index=-100
とか配列のインデックスから外れるような値を指定した場合に torch.nn.functional.cross_entropy(input, target, ..., ignore_index=-100, label_smoothing=1)
の正解値 target
にそのインデックスが入り込むと Out of bounds error が発生した (後述するがweightも指定してて、こっちのが重要)。
ignore_index
を適切に処理できているようには見えず、torch.nn.functional.cross_entropy()
, torch.nn.CrossEntropyLoss
で、label_smoothing > 0
に設定した場合、ignore_index
が機能していない疑惑がある。ので確認してみる。
結論 (TL;DR)
torch.nn.functional.cross_entropy(Input, Target, ...)
において、
weight
(size=[tex:(C,)]
)を指定、ignore_index
をクラスのサイズ以外のインデックスに指定して (e.g. -100, -255)、label_smoothing > 0
,reduction='mean'
、
と指定したとき(超限定的)、内部で呼び出されるtorch._C._nn.cross_entropy_loss
で Target
内のignore_index
の除外がされることなくweight[Target]
と範囲外の参照が行われて(多分)、Out of boundsする。
ignore_index=-100
が想定されているのにエラーが出るってことはバグなんだろう。
既出なのかわからんがIssueに投げた(一応調べた)。
https://github.com/pytorch/pytorch/issues/91383
またignore_index=-100
で無視してほしい場所にも、label_smoothing/n_classes
分?の最低限なノイズが入る様子。対してignore_index=2
のように既存のクラス に含まれるインデックスを指定すると、しっかりとignoreされ(smoothingのノイズなしで)lossが0になる。
⇒ ignore_index
が ignore_index
かそれ以外(ignore_index
, ignore_index
)かでも変わるみたい?で一貫性がないので多分バグ (Issueにまだ投げていない)。
詳細・検証 (再現)
以下、詳細(長い)
環境
- python=3.7.11,
- pytorch=1.10,
Input, Targetの準備, 素のCrossEntropyLoss
import torch import numpy as np import sys print("python={}".format(sys.version_info) ) print("torch.__version__={}".format(torch.__version__)) # python=sys.version_info(major=3, minor=7, micro=11, releaselevel='final', serial=0) # torch.__version__=1.10.0
クラスを3つ()にして3つのサンプル()で、すべてをクラス0だと予測したとする。
# assuming 3 samples (N=3), three-casss (C=3). N=3 # samples (mini-batch size) C=3 # 3 calss _ones = torch.ones(N,1,H,W) _zeros = torch.zeros(N,1,H,W) X = torch.cat([_ones, _zeros, _zeros], dim=1) # respond all sample as class 0 X, X.shape # torch.Size([3, 3]))
Out:
(tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]), torch.Size([3, 3]))
正解値Yはサンプル0, 1, 2でそれぞれクラス0,1,2として、サンプル0のみ正解となる。
Y = torch.tensor([0,1,2]).type(torch.LongTensor) Y # tensor([0, 1, 2])
損失を計算してみると、 TPで 0.5514, FNで1.5514 になる様子。
loss_mean = torch.nn.functional.cross_entropy(X,Y, weight=None, ignore_index=-100, label_smoothing=0) loss = torch.nn.functional.cross_entropy(X,Y, weight=None, ignore_index=-100, label_smoothing=0, reduction='none') print("Loss is {}, ".format(loss_mean)) loss
Out:
Loss is 1.2181113958358765, tensor([0.5514, 1.5514, 1.5514])
1→1 (TP)でloss=0.5514
(サンプル0), 1→0 (FN)でloss=1.5514
(サンプル1,2) となる様子。
ignore_indexを使ってみる (ignore_index=-100
, label_smoothing=0
)
サンプル2 のみignoreされるはず。
ignore_index = -100 X_ignore = X.clone() X_ignore[2,:] = ignore_index X_ignore, X_ignore.shape Y_ignore = Y.clone() Y_ignore[2,] = ignore_index Y_ignore, Y_ignore.shape
Out:
# X_ignore (tensor([[ 1., 0., 0.], [ 1., 0., 0.], [-100., -100., -100.]]), torch.Size([3, 3])) # Y_ignore (tensor([ 0, 1, -100]), torch.Size([3]))
lossがignoreされたサンプル2のみ0.000
になっている(正常にignoreされている)。
loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=-100, label_smoothing=0) loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=ignore_index, label_smoothing=0, reduction='none') print("Loss is {}.".format(loss_mean)) loss
Out:
Loss is 1.05144464969635. tensor([0.5514, 1.5514, 0.0000])
gnore_index=-100
, label_smoothing=0.1
ignoreしてsmoothingしたときの損失を計算してみる。
label_smoothingでは全体的に [tex: (\alpha\mathrm{label_smoothing})/N\mathrm{class}]のノイズが付与されるようなので、全体的に損失がの増減が予想される。(Y = Y * (1 -label_smoothing) +label_smoothing/n_classses)
loss_mean = torch.nn.functional.cross_entropy(X_ignore, Y_ignore, weight=None, ignore_index=-100, label_smoothing=0.1) loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=ignore_index, label_smoothing=0.1, reduction='none') print("Loss is {}.".format(loss_mean)) loss
Out:
Loss is 1.064128041267395. tensor([0.6181, 1.5181, 0.1099])
サンプル1の 1⇒1 (TP)でloss=0.6181
, サンプル2の 1⇒0 (FN)で loss=1.5181
, サンプル3の 1⇒0 (FN, ignored)でloss=0.1099
となっている。
ignore しろって言ってもノイズは適用されるらしい。
ひとまずsmoothingしててもignore_index
は機能している様子。
ignore_index以外の値が含まれていた時
ignore_index=-100
と指定したのに-255が入っていたら、もちろんIndexErrorになる(Out of bounds)。
_Y_ignore255 = Y_ignore.clone() _Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255 # -255 is in Input and Target, but ignore_index=-100 loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none')
Out:
--------------------------------------------------------------------------- IndexError Traceback (most recent call last) /tmp/ipykernel_2483606/1826466474.py in <module> 2 _Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255 3 # -255 is in Input and Target, but ignore_index=-100 ----> 4 loss_noneReduction = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none') /Anaconda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing) 2844 if size_average is not None or reduce is not None: 2845 reduction = _Reduction.legacy_get_string(size_average, reduce) -> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) 2847 2848 IndexError: Target -255 is out of bounds.
weight
も使ってみる (weight=[1,1,1]
, ignore_index=-100
, label_smoothing>0
)
weights を各クラスに1として、フラットな重みを設定する(意味ないね)。
weights = torch.ones(C) weights, weights.shape
Out
(tensor([1., 1., 1.]), torch.Size([3]))
weight=[1,1,1]
, ignore_index=-100
, label_smoothing>0
, reduction='none'
loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=ignore_index, label_smoothing=0.1, reduction='none') loss
Out:
tensor([0.6181, 1.5181, 0.1099])
weight=[1,1,1]
, ignore_index=-100
, label_smoothing>0
, **reduction='mean'**
loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=-100, label_smoothing=0.1, reduction='mean')
Out:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) /tmp/ipykernel_2483606/2859958629.py in <module> ----> 1 loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=-100, label_smoothing=0.1, reduction='mean') /Anaconda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing) 2844 if size_average is not None or reduce is not None: 2845 reduction = _Reduction.legacy_get_string(size_average, reduce) -> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) 2847 2848 RuntimeError: index -100 is out of bounds for dimension 0 with size 3
RuntimeError: index -100 is out of bounds for dimension 0 with size 3
が出た。weightに指定したtensorのインデックス外を参照している様子。weight
つけてignore_index
とlabel_smoothing
使ってreduction=’mean’
としたときのみのバグっぽい(reduction=’sum’
, or ‘none’
ではエラーが出なかった)。
デフォルト値が ignore_index=-100
なので、クラスのインデックス $[0, C)$を超える値と負値も想定しているはずなので、単純なバグっぽい。
torch 1.11.0ではどうなっているかは知らんが1.10.0もまだまだ現役なので治っててほしい(label_smoothingが効果があるかは微妙だが)。
ignore_index
がかそれ以外かでのlabel_smooth
の挙動
ignore_index=-255
,+100
で外 のときでも、サンプル2のloss=1.0986
なのでノイズが入っている。
_Y_ignore255 = Y_ignore.clone() _Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255 # -255 is in Input and Target, but ignore_index=-100 loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none') loss # tensor([0.6181, 1.5181, 1.0986])
_Y_ignore_100 = Y_ignore.clone() _Y_ignore_100[Y_ignore==-100] = 100 # ignore_index: -100 -> 100 not in [0, N_class=3) _loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore_2, weight=torch.ones(C), ignore_index=100, label_smoothing=0.1, reduction='none') _loss # tensor([0.6181, 1.5181, 1.0986])
そうしようかと思いきや、ignore_index=2
in のとき、ignore_index
のサンプル2の損失が0なのでsmoothingのノイズは無し。
_Y_ignore_2 = Y_ignore.clone() _Y_ignore_2[Y_ignore==-100] = 2 # ignore_index: -100 -> 0 in [0, N_class=3) _loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore_2, weight=torch.ones(C), ignore_index=2, label_smoothing=0.1, reduction='none') _loss # tensor([0.6181, 1.5181, 0.0000])
さすがに一貫性がないので、これもバグだと思う。