timmを使ってMixup/CutMixを手軽に実装する方法
![](https://tech.aru-zakki.com/wp-content/uploads/2024/05/mixup-cutmix.001.jpeg)
timmを使うとmixup/cutmixを簡単に実装することができます。この記事ではtimmを使ったmixup/cutmixという2つのデータ拡張の実装方法について解説します。
Mixup/CutMixとは
mixupとcutmixは、複数の画像をブレンドするデータ拡張手法です。これらの手法は、画像分類タスクのモデル性能を向上させるのに非常に有効なため、ディープラーニングのトレーニングで広く用いられています。
1つめのMixupは2つのサンプルを線形に組み合わせることでデータ拡張を行います。具体的には2つの画像を選択し、それぞれを一定比率でブレンドします。このとき、正解ラベルもブレンド比率に従って再計算します
Cutmixは、画像の一部を切り取って他の画像に挿入することでデータ拡張を行う手法です。具体的には、1つ目の画像から矩形の領域をランダムに選択し、矩形内を2つ目の画像に置き換えます。正解ラベルは、領域の面積比率に基づいて再計算します。
この2つはディープラーニングのトレーニングに非常に有用な手法ですが、timm(Pytorch Image Image Models)のライブラリの機能として用意されているため簡単に実装することができます。
この記事では、timmを使ったmixup/cutmixの実装方法について解説します。
![データ拡張(data augmentation)手法のmixupを解説|Pytorch 【初級 深層学習講座】](https://tech.aru-zakki.com/wp-content/uploads/2024/05/pytorch-mixup.001-320x180.jpeg)
Timmのデータ拡張(Mixup)
関数
Mixup/Cutmixを行う関数はtimm.data.mixup.Mixup
という関数に集約されています。利用する場合は以下のようにインポートします
from timm.data.mixup import Mixup
パラメータ
Mixupのパラメータは以下になります。
パラメータ名 | 説明 |
mixup_alpha | mixupのブレンド値を設定します |
cutmix_alpha | cutmixのブレンド値を設定します |
cutmix_minmax | None以外の場合は、cutmixのmin/max比を設定となります |
prob | mixup/cutmixを実行する割合を指定します |
switch_prob | cutmixを選ぶ割合を指定します。0.5のき均等になります |
mode | batch/pair/elem/halfから選択します batchは、ブレンド値やcutmixの領域などがバッチ毎に、pairはペア毎に、elemは要素毎に変わります |
label_smoothing | ラベルスムージングを行う場合に設定します |
num_classes | クラス数を設定します |
使用方法
以下使用方法のサンプルです。サンプル中のinputsはtorch.Size([4, 3, 768, 512])
で、labelsはtensor([0, 1, 0, 1])
とします。
Mixup
mixup_alpha=1
、cutmix_alpha=0
を指定するとmixupのみを行います。今回は、わかりやすくするためにnum_classes=3
としています。
mixup_args = {
'mixup_alpha': 1.,
'cutmix_alpha': 0.,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 3}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
実行結果は以下のようになります。
![original](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-15.png)
![result1](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-16.png)
画像がブレンドされ、比率に従ってラベルが変化していることが分かります。なおmixupすると正解ラベルは、one-hotと同様にクラス数分の列になります。
tensor([[0.9089, 0.0911, 0.0000],
[0.0911, 0.9089, 0.0000],
[0.9089, 0.0911, 0.0000],
[0.0911, 0.9089, 0.0000]])
Mixup(ラベルスムージングあり)
mixupと同時にラベルスムージングを行う例です。ここでは、label_smoothing=0.1
と設定しています。
mixup_args = {
'mixup_alpha': 1.,
'cutmix_alpha': 0.,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0.1,
'num_classes': 5}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
![original](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-17.png)
![result2](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-18.png)
ブレンド後のラベルを見るとクラス0, 1以外のクラスも0.02が設定されています。0.1を5クラスに均等に分割した重みが全てのクラスに加算されているためです。
tensor([[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
[0.4366, 0.5034, 0.0200, 0.0200, 0.0200],
[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
[0.4366, 0.5034, 0.0200, 0.0200, 0.0200]])
Cutmix
Cutmixを行う例です。cutmix_aplha=1
に設定し、mixup_alpha=0
に設定することでcutmixのみ行われるようになります。
mixup_args = {
'mixup_alpha': 0.,
'cutmix_alpha': 1.0,
'cutmix_minmax': None,
'prob': 1.0,
'switch_prob': 0.,
'mode': 'batch',
'label_smoothing': 0,
'num_classes': 3}
inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])
out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)
out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
![original](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-19.png)
mixupと異なり矩形領域が他の画像と入れ替わっていることが分かります。
![result3](https://tech.aru-zakki.com/wp-content/uploads/2024/05/image-20.png)
tensor([[0.6405, 0.3595, 0.0000],
[0.3595, 0.6405, 0.0000],
[0.6405, 0.3595, 0.0000],
[0.3595, 0.6405, 0.0000]])
まとめ
以上、timmのmixup/cutmixの使い方について解説しました。トレーニングコードに、この関数を追加するだけで簡単にmixup/cutmixを追加できます。簡単に実装できる割には精度向上が期待できるデータ拡張です。