機械学習
記事内に商品プロモーションを含む場合があります

TTA(Test Time Augmentation)って何?とりあえず実装してみる

tadanori

TTA(Test Time Augmentation)は、ディープラーニングで精度向上のために利用させるテクニックの1つです。ここでは、実際にTTAをPytorchで実践してみます。

TTA(Test Time Augmentation)とは

TTA(Test Time Augmentation)は、ディープラーニングや機械学習の分野で使用されるテクニックの1つです。

一般的に、データ拡張(Data Augmentation)は、トレーニング時に行われます。

TTAは、データ拡張を推論時に行う点がポイントになります。

TTAでは、推論する画像に対して回転・拡大縮小・水平反転などの変換を行い、変換したデータに対してもモデルによる推論を行います。

最終的には、変換した画像の推論結果の合計や平均、最大値、多数決などを計算し、最終的な予測結果とするものです。

このTTA、どのような場合に利点があるかというと、精度向上が期待できることです。

TTAは、ロバスト性を高めて、性能を向上させることができるのですが、推論時に処理が増加してしまうという欠点もあります。ただ、処理時間が許せばTTAにより精度向上が期待できるわけですからやらない理由もありません。

特に、kaggleのコンペのように「あと少しだけ性能を上げたい」という場合には選択肢の1つとなります。

ここでは、TIMM(PyTorch Image Models)の学習済みモデルを使って、TTAを実際に実装してみます。結構簡単に実装できますので、気軽に試すことが可能です。

ちなみに、TTAではなく、複数の学習済みモデルを合成するModel soupsという技法も精度向上のために利用されます。こちらは、速度低下なく精度向上を行うことが可能です。

Model soupsについてはこちらの記事
【Pytochで実装】Model soups | 重み平均により精度向上させる手法
【Pytochで実装】Model soups | 重み平均により精度向上させる手法

実装例

以下、resnet18の学習済みモデルを使ってTTAを実践してみます。

ライブラリのインポート

必要なライブラリをインポートします。

import torch
import timm
from PIL import Image
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

Google Colabotoryで実行する場合には、timmをインストールする必要があります。

!pip install timm

imagenetのクラス名を読み込む

結果をわかりやすくするために、imagenetのクラス名の一覧を読み込みます。一覧は、以下のURLにあるのでダウンロードするか、git cloneを使ってコピーしてください。

クラス一覧:imagenet1000_clsidx_to_labels.txt

git clone https://gist.github.com/942d3a0ac09ec9e5eb3a.git

ダウンロードするとimagenet1000_clsidx_to_labels.txtという名前のファイルができるのでこれを読み込みます。以下のコードで、テキストが読み込まれclassnameという辞書型に格納されます。

with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
  dat = f.read()

classname = eval(dat)

画像の読み込み

適当な画像を用意します。

今回は以下の画像(cat.jpg)を用意しました。

cat.jpg
img = Image.open("cat.jpg")

モデルの読み込み

モデルを読み込んで推論モードにしておきます。今回はtimmresnet18の学習済みモデルを利用します。ローカルにモデルがない場合は、ダウンロードが開始し学習済みモデルがダウンロードされます。

model = timm.create_model("resnet18", pretrained=True)
model.eval()

通常の推論

以下は通常の推論コードです。imagenetの画像の学習では、決められたパラメータで正規化されているので、正規化を行うようにpreprocessを設定します。

modelの出力は1000クラスに対する重みが出力されるので、argmaxを利用して一番重みの大きなクラスを計算しています。

また、probは、softmaxした結果を取り出しています。

さらに、確率の高い10個を取り出して表示しています。

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

tensor = preprocess(img)
plt.imshow(tensor.permute(1,2,0))

with torch.no_grad():
  pred = model(tensor.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

これをみると約23.8%の確率で”tiger cat”と予測していることがわかります。

id:282 name:tiger cat probability:0.23809778690338135
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
        0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))
正規化された画像

TTAの実装方法

データ拡張した画像を準備

今回は、データ拡張としてリサイズ(Resize)、切り出し(CenterCrop)、水平反転(HorizontalFlip)を行なっています。なお、水平反転にRandomHorizontalFlipを使っていますが確率p=1.0としているので必ず反転されます。

preprocess2の内容がpreprocessと異なるだけで、あとは同じです。

preprocess2 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.RandomHorizontalFlip(1.0),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

tensor2 = preprocess2(img)
plt.imshow(tensor2.permute(1,2,0))

with torch.no_grad():
  pred = model(tensor2.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

データ拡張した画像では、約69.2%の確率で”tiger cat”と予測しておりデータ拡張前と同じ予測結果ですが、確率は大きくなりました。

確率が高くのは、たまたまです。たまたま画像をデータ拡張した画像の方がよかっただけです。

id:282 name:tiger cat probability:0.6920549869537354
torch.return_types.topk(
values=tensor([0.6921, 0.1210, 0.0708, 0.0144, 0.0030, 0.0021, 0.0020, 0.0020, 0.0019,
        0.0016]),
indices=tensor([282, 285, 281, 287, 761, 620, 478, 611, 904, 508]))
データ拡張した画像

TTAを実装

以下、TTAのコードになります。tensortensor2の2つの画像に対して推論を行い、結果を合計しています。

本来は平均を取るべきですが、argmaxをとるだけなので平均を求める必要はないので省いています。また、probもsoftmaxを取るので平均を取らなくても問題ありません。

inputs = [tensor, tensor2]

preds = None
for input in inputs :
  with torch.no_grad():
    pred = model(tensor.unsqueeze(0))
    if preds is None:
      preds = pred
    else:
      preds += pred

class_id = preds.argmax(axis=1).item()
prob = preds.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))

ポイントは、predsに結果を足し込んでいる点です。predは1000クラスのどれかという重みですが、これを足しこむことで全ての画像の結果が集計されることになります。

推論結果はこれまでと同じですが、確率が約79.3%となり、より自信を持って”tiger cat”を選んでいることがわかります。

id:282 name:tiger cat probability:0.7927638292312622
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
        0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))

まとめ

今回は、TTAをすることで精度が向上する例をサンプルにしました。当然、TTAにより精度が低下することもあります。

ただ、精度が向上する傾向が強いです。

データ拡張の方法もいろいろありますが、データ拡張の方法の組み合わせ方によっても精度が変わります。

試行錯誤する部分が結構ありますが、精度向上のテクニックとして覚えておいて損はないです。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました