MOB-LOG

モブおじの記録 (Programming, 統計・機械学習)

(メモ) torch.nn.functional.cross_entropy, 損失計算のweightの設定

(チラ裏の日記、こんなことしてたなあとなります)

クラスごとのピクセル数に偏りがあるため,ピクセル数が多いクラスほど軽視するよう (多いクラスを正解してもlossが余り減らない)に設定していた.(と思っていた)

損失関数 torch.nn.functional.cross_entropy のクラスごとの重みについて CrossEntropyLoss - PyTorch 1.11.0 documentation

weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C

を見間違えて ‘has to be a Tensor of size of C‘、つまりクラスCのサイズ(サンプル数)を指定しろ、という言うようなことが書かれていると思って、そのままクラスごとのサンプル数を与えていた。

正しくは、Cはクラスの数(number of classes)で、普通に

where  x is the input,  y is the target,  w is the weight,  C is the number of classes, and  N spans the minibatch dimension as well as  d_1, ..., d_k for the  K-dimensional case. If reduction is not 'none' (default 'mean'), then …

と書かれていて wの値がそのまま乗算される。

つまりImbalanced dataならピクセル数が多いほど weightを小さく設定すべき.

CrossEntropyLoss - PyTorch 1.11.0 documentation

しかも reduction=’sum’だとウェイトが適用されないらしい。

過去に使ったときは正しく設定していたけど、その時々でドキュメント適当に読んで「ほな、こうか」としがちで時々間違える。 ドキュメントはちゃんと(ただしく)読みましょう。