MOB-LOG

モブおじの記録 (Programming, 統計・機械学習)

torch.nn.functional.cross_entropyで、weight指定してlabel_smoothing>0 のときにOut of bounds エラー(ignore_index が機能していない疑惑)

ignore_index=-100 とか配列のインデックスから外れるような値を指定した場合に torch.nn.functional.cross_entropy(input, target, ..., ignore_index=-100, label_smoothing=1) の正解値 target にそのインデックスが入り込むと Out of bounds error が発生した (後述するがweightも指定してて、こっちのが重要)。

ignore_index=-255, label_smoothing=1, の時の torch.nn.CrossEntropyLoss.foward() のRuntimeError: Out of bounds

ignore_indexを適切に処理できているようには見えず、torch.nn.functional.cross_entropy(), torch.nn.CrossEntropyLossで、label_smoothing > 0に設定した場合、ignore_index が機能していない疑惑がある。ので確認してみる。

結論 (TL;DR)

torch.nn.functional.cross_entropy(Input, Target, ...) において、

  • weight (size=[tex:(C,)])を指定、
  • ignore_index をクラスのサイズ C 以外のインデックスに指定して (e.g. -100, -255)、
  • label_smoothing > 0,
  • reduction='mean'

と指定したとき(超限定的)、内部で呼び出されるtorch._C._nn.cross_entropy_lossTarget内のignore_indexの除外がされることなくweight[Target] と範囲外の参照が行われて(多分)、Out of boundsする。

ignore_index=-100が想定されているのにエラーが出るってことはバグなんだろう。 既出なのかわからんがIssueに投げた(一応調べた)。

https://github.com/pytorch/pytorch/issues/91383

またignore_index=-100で無視してほしい場所にも、label_smoothing/n_classes 分?の最低限なノイズが入る様子。対してignore_index=2のように既存のクラス  [0, N_\mathrm{class}=3) に含まれるインデックスを指定すると、しっかりとignoreされ(smoothingのノイズなしで)lossが0になる。

ignore_index 0 \leq ignore_index  \lt N_{\mathrm{class}}=3 かそれ以外(ignore_index  \lt 0,  N_\mathrm{class} \leq ignore_index)かでも変わるみたい?で一貫性がないので多分バグ (Issueにまだ投げていない)。


詳細・検証 (再現)

以下、詳細(長い)

環境

  • python=3.7.11,
  • pytorch=1.10,

Input, Targetの準備, 素のCrossEntropyLoss

import torch
import numpy as np
import sys
print("python={}".format(sys.version_info) )
print("torch.__version__={}".format(torch.__version__))
# python=sys.version_info(major=3, minor=7, micro=11, releaselevel='final', serial=0)
# torch.__version__=1.10.0

クラスを3つ( C=3 )にして3つのサンプル( N=3 )で、すべてをクラス0だと予測したとする。

# assuming 3 samples (N=3), three-casss (C=3).
N=3 # samples (mini-batch size)
C=3 # 3 calss
_ones = torch.ones(N,1,H,W)
_zeros = torch.zeros(N,1,H,W)
X = torch.cat([_ones, _zeros, _zeros], dim=1) # respond all sample as class 0
X, X.shape
# torch.Size([3, 3]))

Out:

(tensor([[1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.]]),
 torch.Size([3, 3]))

正解値Yはサンプル0, 1, 2でそれぞれクラス0,1,2として、サンプル0のみ正解となる。

Y = torch.tensor([0,1,2]).type(torch.LongTensor)
Y # tensor([0, 1, 2])

損失を計算してみると、 TPで 0.5514, FNで1.5514 になる様子。

loss_mean = torch.nn.functional.cross_entropy(X,Y, weight=None, ignore_index=-100, label_smoothing=0)
loss = torch.nn.functional.cross_entropy(X,Y, weight=None, ignore_index=-100, label_smoothing=0, reduction='none')
print("Loss is {}, ".format(loss_mean))
loss

Out:

Loss is 1.2181113958358765, 
tensor([0.5514, 1.5514, 1.5514])

1→1 (TP)でloss=0.5514 (サンプル0), 1→0 (FN)でloss=1.5514 (サンプル1,2) となる様子。

ignore_indexを使ってみる (ignore_index=-100, label_smoothing=0)

サンプル2 のみignoreされるはず。

ignore_index = -100

X_ignore = X.clone()
X_ignore[2,:] = ignore_index
X_ignore, X_ignore.shape

Y_ignore = Y.clone()
Y_ignore[2,] = ignore_index
Y_ignore, Y_ignore.shape

Out:

# X_ignore
(tensor([[   1.,    0.,    0.],
         [   1.,    0.,    0.],
         [-100., -100., -100.]]),
 torch.Size([3, 3]))
# Y_ignore
(tensor([   0,    1, -100]), torch.Size([3]))

lossがignoreされたサンプル2のみ0.000になっている(正常にignoreされている)。

loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=-100, label_smoothing=0)
loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=ignore_index, label_smoothing=0, reduction='none')
print("Loss is {}.".format(loss_mean))
loss

Out:

Loss is 1.05144464969635.
tensor([0.5514, 1.5514, 0.0000])

gnore_index=-100, label_smoothing=0.1

ignoreしてsmoothingしたときの損失を計算してみる。

label_smoothingでは全体的に [tex: (\alpha\mathrm{label_smoothing})/N\mathrm{class}]のノイズが付与されるようなので、全体的に損失がの増減が予想される。(Y = Y * (1 -label_smoothing) +label_smoothing/n_classses)

loss_mean = torch.nn.functional.cross_entropy(X_ignore, Y_ignore, weight=None, ignore_index=-100, label_smoothing=0.1)
loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=None, ignore_index=ignore_index, label_smoothing=0.1, reduction='none')
print("Loss is {}.".format(loss_mean))
loss

Out:

Loss is 1.064128041267395.
tensor([0.6181, 1.5181, 0.1099])

サンプル1の 1⇒1 (TP)でloss=0.6181, サンプル2の 1⇒0 (FN)で loss=1.5181, サンプル3の 1⇒0 (FN, ignored)でloss=0.1099となっている。

ignore しろって言ってもノイズは適用されるらしい。

ひとまずsmoothingしててもignore_indexは機能している様子。

ignore_index以外の値が含まれていた時

ignore_index=-100と指定したのに-255が入っていたら、もちろんIndexErrorになる(Out of bounds)。

_Y_ignore255 = Y_ignore.clone()
_Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255
# -255 is in Input and Target, but ignore_index=-100
loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none')

Out:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_2483606/1826466474.py in <module>
      2 _Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255
      3 # -255 is in Input and Target, but ignore_index=-100
----> 4 loss_noneReduction = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none')

/Anaconda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844     if size_average is not None or reduce is not None:
   2845         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   2847 
   2848 

IndexError: Target -255 is out of bounds.

weightも使ってみる (weight=[1,1,1], ignore_index=-100, label_smoothing>0)

weights を各クラスに1として、フラットな重みを設定する(意味ないね)。

weights = torch.ones(C)
weights, weights.shape

Out

(tensor([1., 1., 1.]), torch.Size([3]))

weight=[1,1,1], ignore_index=-100, label_smoothing>0, reduction='none'

loss = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=ignore_index, label_smoothing=0.1, reduction='none')
loss

Out:

tensor([0.6181, 1.5181, 0.1099])

weight=[1,1,1], ignore_index=-100, label_smoothing>0, **reduction='mean'**

loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=-100, label_smoothing=0.1, reduction='mean')

Out:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_2483606/2859958629.py in <module>
----> 1 loss_mean = torch.nn.functional.cross_entropy(X_ignore,Y_ignore, weight=weights, ignore_index=-100, label_smoothing=0.1, reduction='mean')

/Anaconda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844     if size_average is not None or reduce is not None:
   2845         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   2847 
   2848 

RuntimeError: index -100 is out of bounds for dimension 0 with size 3

RuntimeError: index -100 is out of bounds for dimension 0 with size 3が出た。weightに指定したtensorのインデックス外を参照している様子。weightつけてignore_indexlabel_smoothing使ってreduction=’mean’としたときのみのバグっぽい(reduction=’sum’, or ‘none’ではエラーが出なかった)。

デフォルト値が ignore_index=-100 なので、クラスのインデックス $[0, C)$を超える値と負値も想定しているはずなので、単純なバグっぽい。

torch 1.11.0ではどうなっているかは知らんが1.10.0もまだまだ現役なので治っててほしい(label_smoothingが効果があるかは微妙だが)。

ignore_index [0, N_\mathrm{class}) かそれ以外かでのlabel_smoothの挙動

ignore_index=-255,+100 [0, N_\mathrm{class}) 外 のときでも、サンプル2のloss=1.0986なのでノイズが入っている。

_Y_ignore255 = Y_ignore.clone()
_Y_ignore255[Y_ignore==-100] = -255 # ignore_index: -100 -> -255
# -255 is in Input and Target, but ignore_index=-100
loss = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore255, weight=None, ignore_index=-100, label_smoothing=0.1, reduction='none')
loss # tensor([0.6181, 1.5181, 1.0986])
_Y_ignore_100 = Y_ignore.clone()
_Y_ignore_100[Y_ignore==-100] = 100 # ignore_index: -100 -> 100 not in [0, N_class=3)
_loss  = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore_2, weight=torch.ones(C), ignore_index=100, label_smoothing=0.1, reduction='none')
_loss # tensor([0.6181, 1.5181, 1.0986])

そうしようかと思いきや、ignore_index=2 in  [0, N_\mathrm{class}) のとき、ignore_indexのサンプル2の損失が0なのでsmoothingのノイズは無し。

_Y_ignore_2 = Y_ignore.clone()
_Y_ignore_2[Y_ignore==-100] = 2 # ignore_index: -100 -> 0 in [0, N_class=3)
_loss  = torch.nn.functional.cross_entropy(X_ignore,_Y_ignore_2, weight=torch.ones(C), ignore_index=2, label_smoothing=0.1, reduction='none')
_loss # tensor([0.6181, 1.5181, 0.0000])

さすがに一貫性がないので、これもバグだと思う。