TTA(Test Time Augmentation)とは?PyTorchでの実装方法

TTA(Test Time Augmentation)は、ディープラーニングで推論の精度を向上させるテクニックです。データ拡張(Augmentation)は、通常は学習時に行いますが、TTAでは推論時に行います。この記事では、TTAのメリットとでデメリット、そしてTTAをPytorchで実装する方法について解説します。
TTA(Test Time Augmentation)とは
TTA(Test Time Augmentation)は、ディープラーニングや機械学習の分野で用いられるテクニックの一つです。一般的には、データ拡張(Data Augmentation)はモデルのトレーニング時に行われます。
TTAは推論時にデータ拡張を行う点が特徴です。
具体的には、TTAでは推論する画像に対して回転や拡大縮小、水平反転などの変換を施し、変換後の画像に対してもモデルによる推論を行います。最終的に、得られた推論結果を合計・平均・最大値・多数決などの方法で集計し、最終的な予測結果を導き出します。
TTAの主な利点は、予測精度を向上させることができる点です。
TTAには、ロバスト性を高める一方で、推論時の処理が増えるため、処理時間が長くなるという欠点があります。
しかし、処理時間が許容できる範囲であれば、精度アップが期待できるため、多くのケースで有用です。特に、Kaggleのコンペティションのように「少しでも性能を上げたい!」という場合には、TTAは有力な選択肢となります。
ここでは、PyTorch Image Models(TIMM)の学習済みモデルを使用して、TTAの実装方法を紹介します。
TTAの実装は比較的簡単で、気軽に試すことが可能です。精度アップの手段の1つとして、気軽に試せるのがメリットです。
なお、TTAとは異なるアプローチとして、複数の学習モデルの組み合わせる「Model soups」という手法があります。こちらは、速度低下なしに精度を向上させることが可能な手法です。
Model soupsについては、以下の記事を参考にしてください。

もちろん、TTAとmodel soupsを組み合わせるて、さらなる精度アップを狙うことも可能です。
実装例
以下、resnet18の学習済みモデルを使ってTTAを実践してみます。
ライブラリのインポート
必要なライブラリをインポートします。
import torch
import timm
from PIL import Image
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
imagenetのクラス名を読み込む
結果をわかりやすくするために、imagenetのクラス名の一覧を読み込みます。一覧は、以下のURLにあるのでダウンロードするか、git cloneを使ってコピーしてください。
クラス一覧:imagenet1000_clsidx_to_labels.txt
git clone https://gist.github.com/942d3a0ac09ec9e5eb3a.git
ダウンロードするとimagenet1000_clsidx_to_labels.txt
という名前のファイルができるのでこれを読み込みます。以下のコードで、テキストが読み込まれclassname
という辞書型に格納されます。
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
dat = f.read()
classname = eval(dat)
画像の読み込み
適当な画像を用意します。
今回は以下の画像(cat.jpg)を用意しました。

img = Image.open("cat.jpg")
モデルの読み込み
モデルを読み込んで推論モードにしておきます。今回はtimm
のresnet18
の学習済みモデルを利用します。ローカルにモデルがない場合は、ダウンロードが開始し学習済みモデルがダウンロードされます。
model = timm.create_model("resnet18", pretrained=True)
model.eval()
通常の推論
以下は通常の推論コードです。imagenetの画像の学習では、決められたパラメータで正規化されているので、正規化を行うようにpreprocess
を設定します。
model
の出力は1000クラスに対する重みが出力されるので、argmax
を利用して一番重みの大きなクラスを計算しています。
また、prob
は、softmax
した結果を取り出しています。
さらに、確率の高い10個を取り出して表示しています。
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
tensor = preprocess(img)
plt.imshow(tensor.permute(1,2,0))
with torch.no_grad():
pred = model(tensor.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
これをみると約23.8%の確率で”tiger cat”と予測していることがわかります。
id:282 name:tiger cat probability:0.23809778690338135
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))

TTAの実装方法
データ拡張した画像を準備
今回は、データ拡張としてリサイズ(Resize
)、切り出し(CenterCrop
)、水平反転(HorizontalFlip
)を行なっています。なお、水平反転にRandomHorizontalFlip
を使っていますが確率p=1.0としているので必ず反転されます。
preprocess2
の内容がpreprocess
と異なるだけで、あとは同じです。
preprocess2 = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.RandomHorizontalFlip(1.0),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
tensor2 = preprocess2(img)
plt.imshow(tensor2.permute(1,2,0))
with torch.no_grad():
pred = model(tensor2.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
データ拡張した画像では、約69.2%の確率で”tiger cat”と予測しておりデータ拡張前と同じ予測結果ですが、確率は大きくなりました。

確率が高くのは、たまたまです。たまたま画像をデータ拡張した画像の方がよかっただけです。
id:282 name:tiger cat probability:0.6920549869537354
torch.return_types.topk(
values=tensor([0.6921, 0.1210, 0.0708, 0.0144, 0.0030, 0.0021, 0.0020, 0.0020, 0.0019,
0.0016]),
indices=tensor([282, 285, 281, 287, 761, 620, 478, 611, 904, 508]))

TTAを実装
以下、TTAのコードになります。tensor
とtensor2
の2つの画像に対して推論を行い、結果を合計しています。

確率なので本来は平均を取るべきだと思います。ただ、argmaxの場合は最も大きなものを選ぶだけなので平均化する必要はないので省きました。
softmaxを取る場合も平均を取らなくても問題ありません。
inputs = [tensor, tensor2]
preds = None
for input in inputs :
with torch.no_grad():
pred = model(tensor.unsqueeze(0))
if preds is None:
preds = pred
else:
preds += pred
class_id = preds.argmax(axis=1).item()
prob = preds.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
ポイントは、preds
に結果を足し込んでいる点です。pred
は1000クラスのどれかという重みですが、これを足しこむことで全ての画像の結果が集計されることになります。
推論結果はこれまでと同じですが、確率が約79.3%となり、より自信を持って”tiger cat”を選んでいることがわかります。
id:282 name:tiger cat probability:0.7927638292312622
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))

この例では、かなり精度がアップしましたが、逆にダウンすることもあります。結果を見ながら使うかどうかは考えましょう。
また、データ拡張の手法の選択も重要です。
まとめ
今回は、TTAをすることで精度が向上する例を紹介しました。TTAは、かなり強力なツールです。うまく使うことで精度向上が期待できます。
ただ、「データ拡張にどれを選ぶか」、「結果を合成する方法をどうするか」など試行錯誤する部分多いです。とりあえず、精度向上のテクニックとして覚えておいて損はないです。