機械学習
記事内に商品プロモーションを含む場合があります

timmを使ってMixup/CutMixを手軽に実装する方法

tadanori

timmを使うとmixup/cutmixを簡単に実装することができます。この記事ではtimmを使ったmixup/cutmixという2つのデータ拡張の実装方法について解説します。

Mixup/CutMixとは

mixupとcutmixは、複数の画像をブレンドするデータ拡張手法です。これらの手法は、画像分類タスクのモデル性能を向上させるのに非常に有効なため、ディープラーニングのトレーニングで広く用いられています。

1つめのMixupは2つのサンプルを線形に組み合わせることでデータ拡張を行います。具体的には2つの画像を選択し、それぞれを一定比率でブレンドします。このとき、正解ラベルもブレンド比率に従って再計算します

Cutmixは、画像の一部を切り取って他の画像に挿入することでデータ拡張を行う手法です。具体的には、1つ目の画像から矩形の領域をランダムに選択し、矩形内を2つ目の画像に置き換えます。正解ラベルは、領域の面積比率に基づいて再計算します。

この2つはディープラーニングのトレーニングに非常に有用な手法ですが、timm(Pytorch Image Image Models)のライブラリの機能として用意されているため簡単に実装することができます。

この記事では、timmを使ったmixup/cutmixの実装方法について解説します。

自分でmixupを実装したい方はこちら
データ拡張(data augmentation)手法のmixupを解説|Pytorch 【初級 深層学習講座】
データ拡張(data augmentation)手法のmixupを解説|Pytorch 【初級 深層学習講座】

Timmのデータ拡張(Mixup)

関数

Mixup/Cutmixを行う関数はtimm.data.mixup.Mixupという関数に集約されています。利用する場合は以下のようにインポートします

from timm.data.mixup import Mixup

パラメータ

Mixupのパラメータは以下になります。

パラメータ名説明
mixup_alphamixupのブレンド値を設定します
cutmix_alphacutmixのブレンド値を設定します
cutmix_minmaxNone以外の場合は、cutmixのmin/max比を設定となります
probmixup/cutmixを実行する割合を指定します
switch_probcutmixを選ぶ割合を指定します。0.5のき均等になります
modebatch/pair/elem/halfから選択します
batchは、ブレンド値やcutmixの領域などがバッチ毎に、pairはペア毎に、elemは要素毎に変わります
label_smoothingラベルスムージングを行う場合に設定します
num_classesクラス数を設定します

使用方法

以下使用方法のサンプルです。サンプル中のinputsはtorch.Size([4, 3, 768, 512])で、labelsはtensor([0, 1, 0, 1])とします。

Mixup

mixup_alpha=1cutmix_alpha=0を指定するとmixupのみを行います。今回は、わかりやすくするためにnum_classes=3としています。

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 3}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)

実行結果は以下のようになります。

original
result1

画像がブレンドされ、比率に従ってラベルが変化していることが分かります。なおmixupすると正解ラベルは、one-hotと同様にクラス数分の列になります。

tensor([[0.9089, 0.0911, 0.0000],
        [0.0911, 0.9089, 0.0000],
        [0.9089, 0.0911, 0.0000],
        [0.0911, 0.9089, 0.0000]])

Mixup(ラベルスムージングあり)

mixupと同時にラベルスムージングを行う例です。ここでは、label_smoothing=0.1と設定しています。

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0.1,
    'num_classes': 5}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
original
result2

ブレンド後のラベルを見るとクラス0, 1以外のクラスも0.02が設定されています。0.1を5クラスに均等に分割した重みが全てのクラスに加算されているためです。

tensor([[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
        [0.4366, 0.5034, 0.0200, 0.0200, 0.0200],
        [0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
        [0.4366, 0.5034, 0.0200, 0.0200, 0.0200]])

Cutmix

Cutmixを行う例です。cutmix_aplha=1に設定し、mixup_alpha=0に設定することでcutmixのみ行われるようになります。

mixup_args = {
    'mixup_alpha': 0.,
    'cutmix_alpha': 1.0,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 3}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
original

mixupと異なり矩形領域が他の画像と入れ替わっていることが分かります。

result3
tensor([[0.6405, 0.3595, 0.0000],
        [0.3595, 0.6405, 0.0000],
        [0.6405, 0.3595, 0.0000],
        [0.3595, 0.6405, 0.0000]])

まとめ

以上、timmのmixup/cutmixの使い方について解説しました。トレーニングコードに、この関数を追加するだけで簡単にmixup/cutmixを追加できます。簡単に実装できる割には精度向上が期待できるデータ拡張です。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました