機械学習
記事内に商品プロモーションを含む場合があります

timmを使ってMixup/CutMixを手軽に実装する方法

Aru

timm(PyTorch Image Models)には、mixupやcutmixといったデータ拡張手法がライブラリの機能として用意されています。これを使えば、自身の学習コードにmixup/cutmixを簡単に実装することが可能です。この記事では、MixupやCutMixの2つを利用する方法について解説します。

Mixup/CutMixとは

mixupとcutmixは、複数の画像をブレンドするデータ拡張手法です。

これらの手法は、画像分類タスクのモデル性能を向上させるのに非常に有効なため、ディープラーニングでは、トレーニング(学習)で多く活用されています。

Mixupは2つのサンプルを線形に組み合わせることでデータ拡張を行います。具体的には2つの画像を選択し、それぞれを一定比率でブレンドします。このとき、正解ラベルもブレンド比率に従って再計算します

CutMixは、画像の一部を切り取って他の画像に挿入することでデータ拡張を行う手法です。具体的には、1つ目の画像から矩形の領域をランダムに選択し、矩形内を2つ目の画像に置き換えます。正解ラベルは、領域の面積比率に基づいて再計算します。

この2つはディープラーニングのトレーニングにおいて非常に有用な手法です。

また、これらの手法はtimm(Pytorch Image Image Models)のライブラリの機能として用意されているため簡単に自身の学習コードに組み込むことが可能です。

ライブラリを利用すれば実装も難しくないので、積極的に使っていきたい手法ではないでしょうか。

この記事では、timmを使ったmixup/cutmixの組み込み方法について解説します。

自分でmixupを実装したい方はこちら
データ拡張(data augmentation)手法のmixupを解説|Pytorchでの実装方法【初級 深層学習講座】
データ拡張(data augmentation)手法のmixupを解説|Pytorchでの実装方法【初級 深層学習講座】

Timmのデータ拡張(Mixup)

関数

Mixup/Cutmixを行う関数はtimm.data.mixup.Mixupという関数に集約されています。このため、Mixupという関数を呼び出すだけで、mixup/cutmixのどちらも使うことができます

この関数を利用するには、まずは以下のようにインポートを行います。

from timm.data.mixup import Mixup

パラメータ

Mixup関数のパラメータ(引数)は以下の通りです。

パラメータ名説明
mixup_alphamixupのブレンド値を設定します
cutmix_alphacutmixのブレンド値を設定します
cutmix_minmaxNone以外の場合は、cutmixのmin/max比を設定となります
probmixup/cutmixを実行する割合を指定します
switch_probcutmixを選ぶ割合を指定します。0.5のき均等になります
modebatch/pair/elem/halfから選択します
batchは、ブレンド値やcutmixの領域などがバッチ毎に、pairはペア毎に、elemは要素毎に変わります
label_smoothingラベルスムージングを行う場合に設定します
num_classesクラス数を設定します

使用方法

以下使用方法のサンプルです。サンプル中のinputsはtorch.Size([4, 3, 768, 512])で、labelsはtensor([0, 1, 0, 1])とします。

Mixupを行う場合

mixup_alpha=1cutmix_alpha=0を指定するとmixupのみを行います。今回は、わかりやすくするためにnum_classes=3としています。

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 3}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)

実行結果は以下のようになります。

original
result1

画像がブレンドされ、比率に従ってラベルも変化していることが分かります。なおmixupすると正解ラベルは、one-hotと同様にクラス数分の列があるテンソルになります。

tensor([[0.9089, 0.0911, 0.0000],
        [0.0911, 0.9089, 0.0000],
        [0.9089, 0.0911, 0.0000],
        [0.0911, 0.9089, 0.0000]])

Mixup(ラベルスムージングあり)

mixupと同時にラベルスムージングを行う例です。ここでは、label_smoothing=0.1と設定しています。

mixup_args = {
    'mixup_alpha': 1.,
    'cutmix_alpha': 0.,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0.1,
    'num_classes': 5}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
original
result2

ブレンド後のラベルを見るとクラス0, 1以外のクラスも0.02が設定されています0.1を5クラスに均等に分割した重みが全てのクラスに加算されているためです。

tensor([[0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
        [0.4366, 0.5034, 0.0200, 0.0200, 0.0200],
        [0.5034, 0.4366, 0.0200, 0.0200, 0.0200],
        [0.4366, 0.5034, 0.0200, 0.0200, 0.0200]])

Cutmixを行う場合

Cutmixを行う例です。cutmix_aplha=1に設定し、mixup_alpha=0に設定することでcutmixのみ行われるようになります。

mixup_args = {
    'mixup_alpha': 0.,
    'cutmix_alpha': 1.0,
    'cutmix_minmax': None,
    'prob': 1.0,
    'switch_prob': 0.,
    'mode': 'batch',
    'label_smoothing': 0,
    'num_classes': 3}

inputs = torch.stack([img0, img1, img2, img3])
labels = torch.tensor([0, 1, 0, 1])

out = torchvision.utils.make_grid(inputs)
plt.imshow(out.permute(1, 2, 0))
plt.show()

mixup_fn = Mixup(**mixup_args)
res_imgs, res_labels = mixup_fn(inputs, labels)

out = torchvision.utils.make_grid(res_imgs)
plt.imshow(out.permute(1, 2, 0))
plt.show()
print(res_labels)
original

mixupと異なり矩形領域が他の画像と入れ替わっていることが分かります。mixupとcutmixは同じように2つの画像をブレンディングするデータ拡張ですが、この画像を見ると手法の違いがよく理解できます。

result3
tensor([[0.6405, 0.3595, 0.0000],
        [0.3595, 0.6405, 0.0000],
        [0.6405, 0.3595, 0.0000],
        [0.3595, 0.6405, 0.0000]])

まとめ

以上、timmのmixup/cutmixの使い方について解説しました。トレーニングコードに、この関数を追加するだけで簡単にmixup/cutmixを追加できます。簡単に実装できる割には精度向上が期待できるデータ拡張です。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました