timmを使ってMixup/CutMixを手軽に実装する方法
timm(PyTorch Image Models)には、mixupやcutmixといったデータ拡張手法がライブラリの機能として用意されています。これを使えば、自身の学習コードにmixup/cutmixを簡単に実装することが可能です。この記事では、MixupやCutMixの2つを利用する方法について解説します。
Mixup/CutMixとは
mixupとcutmixは、複数の画像をブレンドするデータ拡張手法です。
これらの手法は、画像分類タスクのモデル性能を向上させるのに非常に有効なため、ディープラーニングでは、トレーニング(学習)で多く活用されています。
Mixupは2つのサンプルを線形に組み合わせることでデータ拡張を行います。具体的には2つの画像を選択し、それぞれを一定比率でブレンドします。このとき、正解ラベルもブレンド比率に従って再計算します
CutMixは、画像の一部を切り取って他の画像に挿入することでデータ拡張を行う手法です。具体的には、1つ目の画像から矩形の領域をランダムに選択し、矩形内を2つ目の画像に置き換えます。正解ラベルは、領域の面積比率に基づいて再計算します。
この2つはディープラーニングのトレーニングにおいて非常に有用な手法です。
また、これらの手法はtimm(Pytorch Image Image Models)のライブラリの機能として用意されているため簡単に自身の学習コードに組み込むことが可能です。
ライブラリを利用すれば実装も難しくないので、積極的に使っていきたい手法ではないでしょうか。
この記事では、timmを使ったmixup/cutmixの組み込み方法について解説します。
Timmのデータ拡張(Mixup)
関数
Mixup/Cutmixを行う関数はtimm.data.mixup.Mixup
という関数に集約されています。このため、Mixup
という関数を呼び出すだけで、mixup/cutmixのどちらも使うことができます。
この関数を利用するには、まずは以下のようにインポートを行います。
from timm.data.mixup import Mixup
パラメータ
Mixup関数のパラメータ(引数)は以下の通りです。
パラメータ名 | 説明 |
mixup_alpha | mixupのブレンド値を設定します |
cutmix_alpha | cutmixのブレンド値を設定します |
cutmix_minmax | None以外の場合は、cutmixのmin/max比を設定となります |
prob | mixup/cutmixを実行する割合を指定します |
switch_prob | cutmixを選ぶ割合を指定します。0.5のき均等になります |
mode | batch/pair/elem/halfから選択します batchは、ブレンド値やcutmixの領域などがバッチ毎に、pairはペア毎に、elemは要素毎に変わります |
label_smoothing | ラベルスムージングを行う場合に設定します |
num_classes | クラス数を設定します |
使用方法
以下使用方法のサンプルです。サンプル中のinputsはtorch.Size([4, 3, 768, 512])
で、labelsはtensor([0, 1, 0, 1])
とします。
Mixupを行う場合
mixup_alpha=1
、cutmix_alpha=0
を指定するとmixupのみを行います。今回は、わかりやすくするためにnum_classes=3
としています。
mixup_args = {
'mixup_alpha': 1.,
'cutmix_alpha': 0.,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 3}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
実行結果は以下のようになります。
画像がブレンドされ、比率に従ってラベルも変化していることが分かります。なおmixupすると正解ラベルは、one-hotと同様にクラス数分の列があるテンソルになります。
tensor([[0.9089, 0.0911, 0.0000],
[0.0911, 0.9089, 0.0000],
[0.9089, 0.0911, 0.0000],
[0.0911, 0.9089, 0.0000]])
Mixup(ラベルスムージングあり)
mixupと同時にラベルスムージングを行う例です。ここでは、label_smoothing=0.1
と設定しています。
mixup_args = {
'mixup_alpha': 1.,
'cutmix_alpha': 0.,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0.1,
'num_classes': 5}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
ブレンド後のラベルを見るとクラス0, 1以外のクラスも0.02が設定されています。0.1を5クラスに均等に分割した重みが全てのクラスに加算されているためです。
tensor([[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
[0.4366, 0.5034, 0.0200, 0.0200, 0.0200],
[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
[0.4366, 0.5034, 0.0200, 0.0200, 0.0200]])
Cutmixを行う場合
Cutmixを行う例です。cutmix_aplha=1
に設定し、mixup_alpha=0
に設定することでcutmixのみ行われるようになります。
mixup_args = {
'mixup_alpha': 0.,
'cutmix_alpha': 1.0,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 3}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
mixupと異なり矩形領域が他の画像と入れ替わっていることが分かります。mixupとcutmixは同じように2つの画像をブレンディングするデータ拡張ですが、この画像を見ると手法の違いがよく理解できます。
tensor([[0.6405, 0.3595, 0.0000],
[0.3595, 0.6405, 0.0000],
[0.6405, 0.3595, 0.0000],
[0.3595, 0.6405, 0.0000]])
まとめ
以上、timmのmixup/cutmixの使い方について解説しました。トレーニングコードに、この関数を追加するだけで簡単にmixup/cutmixを追加できます。簡単に実装できる割には精度向上が期待できるデータ拡張です。