Pytorch|ラベル平滑化(label smoothing)の実装方法【初級 深層学習講座】
![](https://tech.aru-zakki.com/wp-content/uploads/2024/05/pytorch-labelsmooting.001.jpeg)
この記事では、ラベル平滑化と呼ばれるテクニックについて解説します。ラベル平滑化は、クラス分類タスクでモデルの性能を向上させるテクニックの1つです。
ラベル平滑化(label smoothingとは)
ラベル平滑化(Label Smoothing)は、機械学習や深層学習において、モデルの性能を向上させるために使われるテクニックの一つです。
ラベル平滑化の目的は、学習時の過剰適合(overfitting)の抑制です。
ここでは、具体的にどのような操作を行うかを解説します。
通常のクラス分類タスクの学習では、正解ラベルを1として学習を行います。具体的には、正解ラベルには1を、他のクラスには0を確信度として与えて学習させます(下図右)。
ラベル平滑化では、正解ラベルの確信度を少し下げ、その代わりに他のラベルの確信度をある程度与えて学習を行います。
下図(左)は、ラベル平滑化を行なった例です。ラベル平滑化では、正解ラベルの値を1から0.9のように少し下げ、残りの0.1を他のラベルに割り振ります。
![ラベル平滑化の例](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-6-1024x241.png)
これにより、ラベル平滑化では、モデルがデータに過度に適合するのを防ぎ、一般化能力を高めることができます。
具体的な計算
ラベルスムージングは以下の計算で行うことができます。
正解ラベルがgt = [0,0,1,0,0]
で与えられているとすると、以下の式で表すことができます。
$$
gt = gt * (1-\epsilon) + \frac{\epsilon}{K}
$$
式中の$\epsilon$は、確率分布を滑らかにするために使用されるパラメータです。0から1の間の値を取ります。$K$はクラス数です。
この式は以下の図のようなイメージになります。
![](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-7-1024x256.png)
ラベル平滑化のプログラム例(1)
プログラムでは以下になります。
import torch
e = 0.1
k = 5
gt = torch.tensor([0, 0, 1, 0, 0], dtype=torch.float32)
gt = gt*(1-e) + e/k
print(gt)
# tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
ラベル平滑化のプログラム例(2)
正解ラベルが数値で与えられている場合は、one-hotに直して同様の処理を行います。
import torch
e = 0.1
k = 5
gt = torch.tensor([2])
gt = torch.nn.functional.one_hot(gt, num_classes=k) * (1-e) + e/k
print(gt)
# tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
Pytorchの実装
CrossEntropyLoss
PyTorchのCrossEntropyLoss
には、ラベルスムージングを行うためのパラメータがあります。ここに$\epsilon$を代入することでラベルスムージングすることができます。
import torch
from torch.nn import CrossEntropyLoss
gt = torch.tensor([0, 0, 1, 0, 0], dtype=torch.float32)
pred = torch.tensor([0.1, 0.1, 0.8, 0.1, 0.1], dtype=torch.float32)
loss = CrossEntropyLoss(label_smoothing=0.1)
print(loss(pred, gt))
# tensor(1.1500)
ラベル平滑化を自分で行った場合
CrossEntropyLoss
のラベル平滑化のパラメータの動きを確認するために、自分でラベルスムージングを行ってCrossEntropyLoss
を実行してみます。
下のプログラムでは、この記事で説明した通り、ラベル平滑化を行ってからCrossEntropyLoss
を実行しています。結果は、パラメータとして0.1を渡した場合と同じになります。
e = 0.1
k = 5
gt = torch.tensor([0, 0, 1, 0, 0], dtype=torch.float32)
gt = gt*(1-e) + e/k
print(gt)
# tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
pred = torch.tensor([0.1, 0.1, 0.8, 0.1, 0.1], dtype=torch.float32)
print(pred)
# tensor([0.1000, 0.1000, 0.8000, 0.1000, 0.1000])
print(gt)
# tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
loss = CrossEntropyLoss()
loss(pred, gt)
# tensor(1.1500)
![](https://tech.aru-zakki.com/wp-content/uploads/2023/06/tabbycat.png)
自分でラベル平滑化する方法は、ラベル平滑化をサポートしていない損失関数を利用する場合に使うことができるので、覚えておくと便利です。
まとめ
この記事ではラベル平滑化(label smoothing)について解説しました。簡単な処理で効果がありますので覚えておいて損はないと思います。