PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化
PyTorchに、物体検出やセグメンテーションの結果の可視化関数が用意されているのをご存知でしょうか?この記事では、torchvisionのutilsに用意されている可視化関数について説明します。
YOLOについては、別の可視化ツールを使う方が楽です。
torchvision
torchvisionは、PyTorchのパッケージの一部です。torchvisionには、画像関連(コンピュータビジョン)のデータセットやモデル、画像変換処理が含まれています。
このtorchvisionのutils
というモジュールの中に、可視化用の関数がいくつか用意されています。
意外と便利なので、この記事で紹介します。
utils
の存在に気づかず、都度自分で作っていましたが、これを使う方が楽です
ヘルパー関数を準備
説明の前にヘルパー関数を準備しておきます。
show
関数は、torchvision.utilsの公式にあるものです。使いやすかったので、結果の表示にはこの関数を利用することにします。
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
plt.rcParams["savefig.bbox"] = 'tight'
def show(imgs):
if not isinstance(imgs, list):
imgs = [imgs]
fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.detach()
img = F.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
画像の読み込み
ここでは、cat1.png
とcat2.png
という2つの画像を準備して読み込んでいます。
read_image
関数もtorchvisionに準備されている関数です
画像は適当に準備してください。
from torchvision.io import read_image
cat1 = read_image('cat1.png')
cat2 = read_image('cat2.png')
make_grid
画像をグリッド状に並べた画像を生成する関数です。
入力は、画像のリストか、ミニバッチ形式(B×C×H×W)のデータになります。
(B×C×H×W)がそのまま入力できるので、modelの出力をそのままグリッドに並べることができます。便利!
以下の例は、画像2つを並べる例です。画像を生成し、ヘルパー関数(show
)で表示しています。
from torchvision.utils import make_grid
cats = [cat1, cat2]
grid = make_grid(cats)
show(grid)
次の例は、4つの画像を3列で並べる例です。画像が4枚しかないので、2箇所が空白になります。
cats = [cat1, cat2, cat1, cat2]
grid = make_grid(cats, nrow=3)
show(grid)
make_gridを使うと、画像をグリッド状に並べることができます。出力を確認する場合などに便利です。
物体検出のBBOXの描画
utils
には、検出枠(バウンディングボックス、BBOX)を描画する機能もあります。
ここでは、torchvisionの物体検出モデル(fasterrcnn_resnet50_fpn
)を使って、物体検出を行い、その結果を画像に重畳してみます。
以下は、物体検出するまでのコードです。
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import convert_image_dtype
batch_int = torch.stack([cat1, cat2])
batch = convert_image_dtype(batch_int, dtype=torch.float)
model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()
outputs = model(batch)
outputs[0]
出力は以下のようになります。bboxes
が検出した枠の座標、labels
が検出した枠のクラス、scores
が検出した枠の信頼度になります。今回は、学習済みモデルを利用しましたが、猫は学習されているので、検出されるはずです。
{'boxes': tensor([[ 49.4662, 64.3665, 425.9760, 701.5709],
[ 0.0000, 301.2425, 512.0000, 726.1497],
[298.9560, 18.5366, 509.6454, 422.4392],
[371.1083, 192.5258, 502.9293, 400.5536],
[ 0.0000, 564.7842, 509.1015, 764.2759],
[ 53.3866, 28.6720, 457.0863, 620.9099],
[418.0938, 402.1983, 512.0000, 493.2810],
[ 0.0000, 275.1741, 478.8131, 723.5930],
[403.0338, 406.7412, 512.0000, 551.4507],
[395.6396, 266.7185, 485.6213, 402.8719],
[ 11.5985, 83.1435, 336.1311, 460.2845]], grad_fn=<StackBackward0>),
'labels': tensor([17, 15, 64, 64, 15, 64, 67, 62, 15, 86, 17]),
'scores': tensor([0.9889, 0.5225, 0.4071, 0.3121, 0.1079, 0.0920, 0.0879, 0.0828, 0.0782,
0.0743, 0.0509], grad_fn=<IndexBackward0>)}
utils
のdraw_bounding_boxes
を使って描画するコードです。閾値を超える信頼度の枠だけを引数boxes
に渡しています。width
は、枠の線幅です。
from torchvision.utils import draw_bounding_boxes
score_threshold = 0.9
boxes = [
draw_bounding_boxes(img, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
for img, output in zip(batch_int, outputs)
]
show(boxes)
結果は以下になります。結果を見ると猫と、本と、椅子を検出しているようです。
セグメンテーションの描画
utils
には、セグメンテーションの結果を表示する関数も準備されています。
ここでは、torchvisionのセグメンテーションモデル(fcn_resnet50
)を使って、セグメンテーションを行い、その結果を画像に重畳してみます。
以下は、セグメンテーションするコードです
物体検出・セグメンテーションともにtorchvisionを使えば簡単です
学習済みモデルを使うだけなら、ほんと簡単ですよね
from torchvision.models.segmentation import fcn_resnet50
model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()
normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
output = model(normalized_batch)['out']
output[0].shape
結果は、画像サイズと同じで、21チャネルのデータになります。
torch.Size([21, 768, 512])
21チャネルは、21のクラスに対応していて、それぞれ以下のコードのsem_classes
になります。今回は、この中からcat
の部分だけ取り出して、重畳してみます。
cat
だけを取り出すコードは以下になります。やっているのは、正規化したマスクデータの最大のチャネルの番号を取り出して、それが猫のラベルと一致している部分だけを取り出します。
マスクを正規化するためにsoftmaxをしていますが、argmaxするだけならここは必要ないかも
また、show
で描画するためにTrue/Falseをfloatに変換しています
sem_classes = [
'__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
normalized_masks = torch.nn.functional.softmax(output, dim=1)
class_dim = 1
bool_cat = (normalized_masks.argmax(class_dim) == sem_class_to_idx['cat'])
show([m.float() for m in bool_cat])
出来上がったマスクは以下になります。
これを、utils
のdraw_segmentation_masks
を使って画像に重畳します。
from torchvision.utils import draw_segmentation_masks
cats_masks = [
draw_segmentation_masks(img, masks=mask, alpha=0.7, colors="red")
for img, mask in zip(batch_int, bool_cat)
]
show(cats_masks)
まとめ
以上、PyTorchのtorchvisionを使った可視化処理について説明しました。
utils
に用意されているのを知らずに自作していましたが、用意されている関数を使った方が楽なので、そちらを利用した方が良いと思います。
YOLOの検出結果についてはsupervison
を使った方が楽です。supervisionについては以下の記事にあります。