機械学習
記事内に商品プロモーションを含む場合があります

TTA(Test Time Augmentation)とは?PyTorchでの実装方法

Aru

TTA(Test Time Augmentation)は、ディープラーニングで推論の精度を向上させるテクニックです。データ拡張(Augmentation)は、通常は学習時に行いますが、TTAでは推論時に行います。この記事では、TTAのメリットとでデメリット、そしてTTAをPytorchで実装する方法について解説します。

TTA(Test Time Augmentation)とは

TTA(Test Time Augmentation)は、ディープラーニングや機械学習の分野で用いられるテクニックの一つです。一般的には、データ拡張(Data Augmentation)はモデルのトレーニング時に行われます。

TTAは推論時にデータ拡張を行う点が特徴です。

具体的には、TTAでは推論する画像に対して回転や拡大縮小、水平反転などの変換を施し、変換後の画像に対してもモデルによる推論を行います。最終的に、得られた推論結果を合計・平均・最大値・多数決などの方法で集計し、最終的な予測結果を導き出します。

TTAの主な利点は、予測精度を向上させることができる点です。

TTAには、ロバスト性を高める一方で、推論時の処理が増えるため、処理時間が長くなるという欠点があります。

しかし、処理時間が許容できる範囲であれば、精度アップが期待できるため、多くのケースで有用です。特に、Kaggleのコンペティションのように「少しでも性能を上げたい!」という場合には、TTAは有力な選択肢となります。

ここでは、PyTorch Image Models(TIMM)の学習済みモデルを使用して、TTAの実装方法を紹介します。

TTAの実装は比較的簡単で、気軽に試すことが可能です。精度アップの手段の1つとして、気軽に試せるのがメリットです。

なお、TTAとは異なるアプローチとして、複数の学習モデルの組み合わせる「Model soups」という手法があります。こちらは、速度低下なしに精度を向上させることが可能な手法です。

Model soupsについては、以下の記事を参考にしてください。

Model soupsについてはこちらの記事
重みの平均化で精度向上させるModel SoupsをPyTorchで実装する
重みの平均化で精度向上させるModel SoupsをPyTorchで実装する

もちろん、TTAとmodel soupsを組み合わせるて、さらなる精度アップを狙うことも可能です。

実装例

以下、resnet18の学習済みモデルを使ってTTAを実践してみます。

ライブラリのインポート

必要なライブラリをインポートします。

import torch
import timm
from PIL import Image
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

Google Colabotoryで実行する場合には、timmをインストールする必要があります。

!pip install timm

imagenetのクラス名を読み込む

結果をわかりやすくするために、imagenetのクラス名の一覧を読み込みます。一覧は、以下のURLにあるのでダウンロードするか、git cloneを使ってコピーしてください。

クラス一覧:imagenet1000_clsidx_to_labels.txt

git clone https://gist.github.com/942d3a0ac09ec9e5eb3a.git

ダウンロードするとimagenet1000_clsidx_to_labels.txtという名前のファイルができるのでこれを読み込みます。以下のコードで、テキストが読み込まれclassnameという辞書型に格納されます。

with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
  dat = f.read()

classname = eval(dat)

画像の読み込み

適当な画像を用意します。

今回は以下の画像(cat.jpg)を用意しました。

cat.jpg
img = Image.open("cat.jpg")

モデルの読み込み

モデルを読み込んで推論モードにしておきます。今回はtimmresnet18の学習済みモデルを利用します。ローカルにモデルがない場合は、ダウンロードが開始し学習済みモデルがダウンロードされます。

model = timm.create_model("resnet18", pretrained=True)
model.eval()

通常の推論

以下は通常の推論コードです。imagenetの画像の学習では、決められたパラメータで正規化されているので、正規化を行うようにpreprocessを設定します。

modelの出力は1000クラスに対する重みが出力されるので、argmaxを利用して一番重みの大きなクラスを計算しています。

また、probは、softmaxした結果を取り出しています。

さらに、確率の高い10個を取り出して表示しています。

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

tensor = preprocess(img)
plt.imshow(tensor.permute(1,2,0))

with torch.no_grad():
  pred = model(tensor.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

これをみると約23.8%の確率で”tiger cat”と予測していることがわかります。

id:282 name:tiger cat probability:0.23809778690338135
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
        0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))
正規化された画像

TTAの実装方法

データ拡張した画像を準備

今回は、データ拡張としてリサイズ(Resize)、切り出し(CenterCrop)、水平反転(HorizontalFlip)を行なっています。なお、水平反転にRandomHorizontalFlipを使っていますが確率p=1.0としているので必ず反転されます。

preprocess2の内容がpreprocessと異なるだけで、あとは同じです。

preprocess2 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.RandomHorizontalFlip(1.0),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

tensor2 = preprocess2(img)
plt.imshow(tensor2.permute(1,2,0))

with torch.no_grad():
  pred = model(tensor2.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

データ拡張した画像では、約69.2%の確率で”tiger cat”と予測しておりデータ拡張前と同じ予測結果ですが、確率は大きくなりました。

確率が高くのは、たまたまです。たまたま画像をデータ拡張した画像の方がよかっただけです。

id:282 name:tiger cat probability:0.6920549869537354
torch.return_types.topk(
values=tensor([0.6921, 0.1210, 0.0708, 0.0144, 0.0030, 0.0021, 0.0020, 0.0020, 0.0019,
        0.0016]),
indices=tensor([282, 285, 281, 287, 761, 620, 478, 611, 904, 508]))
データ拡張した画像

TTAを実装

以下、TTAのコードになります。tensortensor2の2つの画像に対して推論を行い、結果を合計しています。

確率なので本来は平均を取るべきだと思います。ただ、argmaxの場合は最も大きなものを選ぶだけなので平均化する必要はないので省きました。

softmaxを取る場合も平均を取らなくても問題ありません。

inputs = [tensor, tensor2]

preds = None
for input in inputs :
  with torch.no_grad():
    pred = model(tensor.unsqueeze(0))
    if preds is None:
      preds = pred
    else:
      preds += pred

class_id = preds.argmax(axis=1).item()
prob = preds.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

ポイントは、predsに結果を足し込んでいる点です。predは1000クラスのどれかという重みですが、これを足しこむことで全ての画像の結果が集計されることになります。

推論結果はこれまでと同じですが、確率が約79.3%となり、より自信を持って”tiger cat”を選んでいることがわかります。

id:282 name:tiger cat probability:0.7927638292312622
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
        0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))

この例では、かなり精度がアップしましたが、逆にダウンすることもあります。結果を見ながら使うかどうかは考えましょう。

また、データ拡張の手法の選択も重要です。

まとめ

今回は、TTAをすることで精度が向上する例を紹介しました。TTAは、かなり強力なツールです。うまく使うことで精度向上が期待できます。

ただ、「データ拡張にどれを選ぶか」、「結果を合成する方法をどうするか」など試行錯誤する部分多いです。とりあえず、精度向上のテクニックとして覚えておいて損はないです。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました