機械学習
記事内に商品プロモーションを含む場合があります

PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化

tadanori

PyTorchに、物体検出やセグメンテーションの結果の可視化関数が用意されているのをご存知でしょうか?この記事では、torchvisionのutilsに用意されている可視化関数について説明します。

YOLOについては、別の可視化ツールを使う方が楽です。

YOLOの出力を可視化するツール「supervision」を紹介
YOLOの出力を可視化するツール「supervision」を紹介

torchvision

torchvisionは、PyTorchのパッケージの一部です。torchvisionには、画像関連(コンピュータビジョン)のデータセットやモデル、画像変換処理が含まれています。

このtorchvisionのutilsというモジュールの中に、可視化用の関数がいくつか用意されています。

意外と便利なので、この記事で紹介します。

utilsの存在に気づかず、都度自分で作っていましたが、これを使う方が楽です

ヘルパー関数を準備

説明の前にヘルパー関数を準備しておきます。

show関数は、torchvision.utilsの公式にあるものです。使いやすかったので、結果の表示にはこの関数を利用することにします。

import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F


plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

画像の読み込み

ここでは、cat1.pngcat2.pngという2つの画像を準備して読み込んでいます。

read_image関数もtorchvisionに準備されている関数です

画像は適当に準備してください。

from torchvision.io import read_image

cat1 = read_image('cat1.png')
cat2 = read_image('cat2.png')

make_grid

画像をグリッド状に並べた画像を生成する関数です。

入力は、画像のリストか、ミニバッチ形式(B×C×H×W)のデータになります。

(B×C×H×W)がそのまま入力できるので、modelの出力をそのままグリッドに並べることができます。便利!

以下の例は、画像2つを並べる例です。画像を生成し、ヘルパー関数(show)で表示しています。

from torchvision.utils import make_grid

cats = [cat1, cat2]

grid = make_grid(cats)
show(grid)

次の例は、4つの画像を3列で並べる例です。画像が4枚しかないので、2箇所が空白になります。

cats = [cat1, cat2, cat1, cat2]

grid = make_grid(cats, nrow=3)

show(grid)

make_gridを使うと、画像をグリッド状に並べることができます。出力を確認する場合などに便利です。

物体検出のBBOXの描画

utilsには、検出枠(バウンディングボックス、BBOX)を描画する機能もあります。

ここでは、torchvisionの物体検出モデル(fasterrcnn_resnet50_fpn)を使って、物体検出を行い、その結果を画像に重畳してみます。

以下は、物体検出するまでのコードです。

from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import convert_image_dtype

batch_int = torch.stack([cat1, cat2])
batch = convert_image_dtype(batch_int, dtype=torch.float)

model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()

outputs = model(batch)
outputs[0]

出力は以下のようになります。bboxesが検出した枠の座標、labelsが検出した枠のクラス、scoresが検出した枠の信頼度になります。今回は、学習済みモデルを利用しましたが、猫は学習されているので、検出されるはずです。

{'boxes': tensor([[ 49.4662,  64.3665, 425.9760, 701.5709],
         [  0.0000, 301.2425, 512.0000, 726.1497],
         [298.9560,  18.5366, 509.6454, 422.4392],
         [371.1083, 192.5258, 502.9293, 400.5536],
         [  0.0000, 564.7842, 509.1015, 764.2759],
         [ 53.3866,  28.6720, 457.0863, 620.9099],
         [418.0938, 402.1983, 512.0000, 493.2810],
         [  0.0000, 275.1741, 478.8131, 723.5930],
         [403.0338, 406.7412, 512.0000, 551.4507],
         [395.6396, 266.7185, 485.6213, 402.8719],
         [ 11.5985,  83.1435, 336.1311, 460.2845]], grad_fn=<StackBackward0>),
 'labels': tensor([17, 15, 64, 64, 15, 64, 67, 62, 15, 86, 17]),
 'scores': tensor([0.9889, 0.5225, 0.4071, 0.3121, 0.1079, 0.0920, 0.0879, 0.0828, 0.0782,
         0.0743, 0.0509], grad_fn=<IndexBackward0>)}

utilsdraw_bounding_boxesを使って描画するコードです。閾値を超える信頼度の枠だけを引数boxesに渡しています。widthは、枠の線幅です。

from torchvision.utils import draw_bounding_boxes

score_threshold = 0.9
boxes = [
    draw_bounding_boxes(img, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
    for img, output in zip(batch_int, outputs)
]

show(boxes)  

結果は以下になります。結果を見ると猫と、本と、椅子を検出しているようです。

セグメンテーションの描画

utilsには、セグメンテーションの結果を表示する関数も準備されています。

ここでは、torchvisionのセグメンテーションモデル(fcn_resnet50)を使って、セグメンテーションを行い、その結果を画像に重畳してみます。

以下は、セグメンテーションするコードです

物体検出・セグメンテーションともにtorchvisionを使えば簡単です

学習済みモデルを使うだけなら、ほんと簡単ですよね

from torchvision.models.segmentation import fcn_resnet50

model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()

normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
output = model(normalized_batch)['out']
output[0].shape

結果は、画像サイズと同じで、21チャネルのデータになります。

torch.Size([21, 768, 512])

21チャネルは、21のクラスに対応していて、それぞれ以下のコードのsem_classesになります。今回は、この中からcatの部分だけ取り出して、重畳してみます。

catだけを取り出すコードは以下になります。やっているのは、正規化したマスクデータの最大のチャネルの番号を取り出して、それが猫のラベルと一致している部分だけを取り出します。

マスクを正規化するためにsoftmaxをしていますが、argmaxするだけならここは必要ないかも

また、showで描画するためにTrue/Falseをfloatに変換しています

sem_classes = [
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
normalized_masks = torch.nn.functional.softmax(output, dim=1)


class_dim = 1
bool_cat = (normalized_masks.argmax(class_dim) == sem_class_to_idx['cat'])
show([m.float() for m in bool_cat])

出来上がったマスクは以下になります。

これを、utilsdraw_segmentation_masksを使って画像に重畳します。

from torchvision.utils import draw_segmentation_masks

cats_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7, colors="red")
    for img, mask in zip(batch_int, bool_cat)
]
show(cats_masks)

まとめ

以上、PyTorchのtorchvisionを使った可視化処理について説明しました。

utilsに用意されているのを知らずに自作していましたが、用意されている関数を使った方が楽なので、そちらを利用した方が良いと思います。

YOLOの検出結果についてはsupervisonを使った方が楽です。supervisionについては以下の記事にあります。

あわせて読みたい
YOLOの出力を可視化するツール「supervision」を紹介
YOLOの出力を可視化するツール「supervision」を紹介

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました