TTA(Test Time Augmentation)って何?とりあえず実装してみる
TTA(Test Time Augmentation)は、ディープラーニングで精度向上のために利用させるテクニックの1つです。ここでは、実際にTTAをPytorchで実践してみます。
TTA(Test Time Augmentation)とは
TTA(Test Time Augmentation)は、ディープラーニングや機械学習の分野で使用されるテクニックの1つです。
一般的に、データ拡張(Data Augmentation)は、トレーニング時に行われます。
TTAは、データ拡張を推論時に行う点がポイントになります。
TTAでは、推論する画像に対して回転・拡大縮小・水平反転などの変換を行い、変換したデータに対してもモデルによる推論を行います。
最終的には、変換した画像の推論結果の合計や平均、最大値、多数決などを計算し、最終的な予測結果とするものです。
このTTA、どのような場合に利点があるかというと、精度向上が期待できることです。
TTAは、ロバスト性を高めて、性能を向上させることができるのですが、推論時に処理が増加してしまうという欠点もあります。ただ、処理時間が許せばTTAにより精度向上が期待できるわけですからやらない理由もありません。
特に、kaggleのコンペのように「あと少しだけ性能を上げたい」という場合には選択肢の1つとなります。
ここでは、TIMM(PyTorch Image Models)の学習済みモデルを使って、TTAを実際に実装してみます。結構簡単に実装できますので、気軽に試すことが可能です。
ちなみに、TTAではなく、複数の学習済みモデルを合成するModel soupsという技法も精度向上のために利用されます。こちらは、速度低下なく精度向上を行うことが可能です。
実装例
以下、resnet18の学習済みモデルを使ってTTAを実践してみます。
ライブラリのインポート
必要なライブラリをインポートします。
import torch
import timm
from PIL import Image
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
imagenetのクラス名を読み込む
結果をわかりやすくするために、imagenetのクラス名の一覧を読み込みます。一覧は、以下のURLにあるのでダウンロードするか、git cloneを使ってコピーしてください。
クラス一覧:imagenet1000_clsidx_to_labels.txt
git clone https://gist.github.com/942d3a0ac09ec9e5eb3a.git
ダウンロードするとimagenet1000_clsidx_to_labels.txt
という名前のファイルができるのでこれを読み込みます。以下のコードで、テキストが読み込まれclassname
という辞書型に格納されます。
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
dat = f.read()
classname = eval(dat)
画像の読み込み
適当な画像を用意します。
今回は以下の画像(cat.jpg)を用意しました。
img = Image.open("cat.jpg")
モデルの読み込み
モデルを読み込んで推論モードにしておきます。今回はtimm
のresnet18
の学習済みモデルを利用します。ローカルにモデルがない場合は、ダウンロードが開始し学習済みモデルがダウンロードされます。
model = timm.create_model("resnet18", pretrained=True)
model.eval()
通常の推論
以下は通常の推論コードです。imagenetの画像の学習では、決められたパラメータで正規化されているので、正規化を行うようにpreprocess
を設定します。
model
の出力は1000クラスに対する重みが出力されるので、argmax
を利用して一番重みの大きなクラスを計算しています。
また、prob
は、softmax
した結果を取り出しています。
さらに、確率の高い10個を取り出して表示しています。
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
tensor = preprocess(img)
plt.imshow(tensor.permute(1,2,0))
with torch.no_grad():
pred = model(tensor.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
これをみると約23.8%の確率で”tiger cat”と予測していることがわかります。
id:282 name:tiger cat probability:0.23809778690338135
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))
TTAの実装方法
データ拡張した画像を準備
今回は、データ拡張としてリサイズ(Resize
)、切り出し(CenterCrop
)、水平反転(HorizontalFlip
)を行なっています。なお、水平反転にRandomHorizontalFlip
を使っていますが確率p=1.0としているので必ず反転されます。
preprocess2
の内容がpreprocess
と異なるだけで、あとは同じです。
preprocess2 = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.RandomHorizontalFlip(1.0),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
tensor2 = preprocess2(img)
plt.imshow(tensor2.permute(1,2,0))
with torch.no_grad():
pred = model(tensor2.unsqueeze(0))
class_id = pred.argmax(axis=1).item()
prob = pred.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
データ拡張した画像では、約69.2%の確率で”tiger cat”と予測しておりデータ拡張前と同じ予測結果ですが、確率は大きくなりました。
確率が高くのは、たまたまです。たまたま画像をデータ拡張した画像の方がよかっただけです。
id:282 name:tiger cat probability:0.6920549869537354
torch.return_types.topk(
values=tensor([0.6921, 0.1210, 0.0708, 0.0144, 0.0030, 0.0021, 0.0020, 0.0020, 0.0019,
0.0016]),
indices=tensor([282, 285, 281, 287, 761, 620, 478, 611, 904, 508]))
TTAを実装
以下、TTAのコードになります。tensor
とtensor2
の2つの画像に対して推論を行い、結果を合計しています。
本来は平均を取るべきですが、argmaxをとるだけなので平均を求める必要はないので省いています。また、probもsoftmaxを取るので平均を取らなくても問題ありません。
inputs = [tensor, tensor2]
preds = None
for input in inputs :
with torch.no_grad():
pred = model(tensor.unsqueeze(0))
if preds is None:
preds = pred
else:
preds += pred
class_id = preds.argmax(axis=1).item()
prob = preds.softmax(axis=1)[0, class_id]
print(f"id:{class_id} name:{classname[class_id]} probability:{prob}")
p = pred.softmax(axis=1)[0]
print(p.topk(k=10))
ポイントは、preds
に結果を足し込んでいる点です。pred
は1000クラスのどれかという重みですが、これを足しこむことで全ての画像の結果が集計されることになります。
推論結果はこれまでと同じですが、確率が約79.3%となり、より自信を持って”tiger cat”を選んでいることがわかります。
id:282 name:tiger cat probability:0.7927638292312622
torch.return_types.topk(
values=tensor([0.2381, 0.0985, 0.0605, 0.0203, 0.0099, 0.0092, 0.0058, 0.0053, 0.0052,
0.0049]),
indices=tensor([282, 285, 281, 287, 284, 292, 750, 356, 904, 543]))
まとめ
今回は、TTAをすることで精度が向上する例をサンプルにしました。当然、TTAにより精度が低下することもあります。
ただ、精度が向上する傾向が強いです。
データ拡張の方法もいろいろありますが、データ拡張の方法の組み合わせ方によっても精度が変わります。
試行錯誤する部分が結構ありますが、精度向上のテクニックとして覚えておいて損はないです。