機械学習
記事内に商品プロモーションを含む場合があります

FastSAMによるゼロショットセグメンテーション|使い方を解説

tadanori

Zero-shotのセグメンテーションを実現するMetaのSegmentation Anything Model(SAM)より高速なセグメンテーションモデルです。FastSAMも使ってみたので、SAMとの違いも含めて解説します。

FAST-SAMとは

FAST Segmentation Anything Model(FastSAM)は、SAMのデータセットのわずか2%で訓練されたCNNベースのゼロショットのセグメンテーションモデルです。

作者によると、SAMの50倍高速ということです。

オリジナルのSAMがViTをバックボーンに使っていたのに対して、CNNを使うことで高速化を実現しています。ちなみに、CNNの部分はYOLOv8をベースにしているみたいです。

今回は、FastSAMを実際にうとかして、実力をチェックしてみました。

SAMも動かしてみているのでそちらとの比較もしていきます。SAMの記事は以下になります。

SAMによるゼロショットセグメンテーション|使い方を解説
SAMによるゼロショットセグメンテーション|使い方を解説

FastSAM(Github): https://github.com/CASIA-IVA-Lab/FastSAM

この記事のコードは、以下にあります

Google Colabで動作するコードをこちら(Github)に用意しました

公式のGoogle Colabでの動作テストコードが、コマンドラインツールを呼び出すタイプでしたので、Pythonから呼び出す方法を調べながら作成しています。

ドキュメントが少ないので、もっと良い使い方があるかもしれません

前準備

この記事のコードについて

この記事のコードは、Google Colabで動かすこをと前提としています。一部、Colab独自の手順やパッケージなどがありますので注意してください。

FastSAMのインストール

tensorflow-probabilityをアンインストールする

Google Colabの場合、tensorflow-probabilityのバージョンでエラーが出てインストールが失敗しました。とりあえず、アンインストールして回避しています(2023.12.1)

最初に、tensorflow-probabilityをアンインストールしておきます。

!pip uninstall --yes tensorflow-probability

FastSAMのインストール

FastSAMをインストールするには、以下のコマンドを実行します。

!git clone https://github.com/CASIA-IVA-Lab/FastSAM.git
%cd FastSAM
!pip install -r requirements.txt

Clipのインストール

FastSAMの動作に必要となるCLIPをインストールします。

!pip install git+https://github.com/openai/CLIP.git

以上で、パッケージのインストールは完了です。

途中、ERROR, WARNINGが出るかもしれません

ライブラリのインポート

今回の実験で使うライブラリをインポートします。

from google.colab.patches import cv2_imshowは、Google Colab用のcv2.imshow()の代替関数です。Colabでは、cv2.imshow()の代わりにこちらを使います。

import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

from fastsam import FastSAM, FastSAMPrompt
import torch
import numpy as np

モデルのダウンロード

モデルをダウンロードします。モデルはGoogle Drive上で共有されています。Google Driveからのダウンロードは、「curlやwgetで公開済みGoogle Driveデータをダウンロードする」を参考にしています。

# https://qiita.com/namakemono/items/c963e75e0af3f7eed732
!curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv" > /dev/null
!CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"  ; curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv" -o FastSAM-x.pt

画像の読み込み

画像は、こちらに用意しました。Colabのコードセルでwgetを実行してダウンロードします。

!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true

ダウンロードした画像を読み込みます。

image = cv2.imread('cat2.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(5,5))
plt.imshow(image)
plt.axis('on')
plt.show()

テスト画像は、SAMで使用したものと同じです

テスト画像

モデルを読み込む

モデルを読み込みます。以下では、prompt_processを利用して処理を行っていきます。

model = FastSAM('./FastSAM-x.pt')
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
everything_results = model(image, device=device, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
prompt_process = FastSAMPrompt(image, everything_results, device=device)

画像全体に自動セグメンテーション

画像全体の自動セグメンテーションは、everything_promptを利用します。

ann = prompt_process.everything_prompt()

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

出力結果は以下になります。出力の感じから、おそらくYOLOv8のセグメンテーションの機能を利用していると思われます。また、ann.shape[20,768,512]となっていることから全部で20個のセグメントが出力されていることも確認できます。

torch.Size([20, 768, 512])
全体のセグメンテーション

YOLOv8のセグメンテーションについては以下の記事を参考にしてください

YOLOv8でセグメンテーション|学習と推論を実践
YOLOv8でセグメンテーション|学習と推論を実践

場所を指定してセグメンテーション

SAMと同様に、点、領域を指定してセグメンテーションを行なってみます

単一の点を指定してセグメンテーション

画像の1点を指定して、その点を含むセグメントを検出させた例です。

点を指定したセグメンテーションはpoint_promptを利用します。引数のpointsは、画像中の点の座標、pointlabelは0または1を指定します。1の場合は「含める」指定、0の場合は「除外する」指定になるようです。

prompt_process.plotの引数pointspoint_labelに、点の座標とラベルを渡すと、画像中にプロットしてくれます。

input_point = np.array([[200, 200]])
input_label = np.array([1])

ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

結果は以下になります。シアン色の点が指定した点、赤い枠の中が点から抽出されたセグメントになります。結果を見ると、猫が選択されたようです。

1点を指定してセグメンテーション

複数点を指定してセグメンテーション

点の指定は複数行うこともできます。

以下の例では、2点を指定しています。

input_point = np.array([[200, 200], [400, 650]])
input_label = np.array([1, 1])

ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

猫と本を指定しましたが、ちゃんとセグメントを抽出しているようです。

2点を指定してセグメンテーション

除外する点を指定してセグメンテーション

SAMと同様に除外したい点も指定することができます。

除外したい場所は、pointlabelに0を設定します

input_point = np.array([[200, 200], [400, 650], [200, 600]])
input_label = np.array([1, 1, 0])


ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                                        points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

シアン色の点がセグメントとして取り出す部分、マゼンダ色の点が除外する点です。

手の部分だけ除外しようとしましたが、猫全体が除外されてしまったようです。

SAMでは多少、境界が変でしたが手が除外されていました。FastSAMではうまくいきませんでした。

除外点を加えてセグメンテーション

矩形領域を指定してセグメンテーション

矩形の領域を指定して、セグメンテーションすることも可能です。

矩形領域でセグメンテーションする場合は、box_promptを使います。

input_boxes = [
    [0, 250, 150, 600],
]

ann = prompt_process.box_prompt(bboxes = input_boxes)

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    bboxes = input_boxes,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

赤い矩形が指定した領域です。この例では、椅子をちゃんとセグメンテーションすることができました。

(1, 768, 512)
矩形領域を指定してセグメンテーション

複数の領域を指定してセグメンテーション

複数の領域を設定して、セグメンテーションを行うことも可能です。

input_boxes = [
    [0, 250, 150, 600],
    [50, 50, 400, 650],
    [80, 550, 512, 768]
]

ann = prompt_process.box_prompt(bboxes = input_boxes)

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    bboxes = input_boxes,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

椅子、猫、本の3つの領域を指定してセグメンテーションさせましたが、それぞれうまく抜き出せています。なお、領域を3つ指定したので、マスクも3つ出力されています

椅子、猫、本は別のマスクとして出力されています

(3, 768, 512)

テキストを指定してセグメンテーション

サンプルを見ると、テキストでもセグメンテーションできるみたいです。

テキストでのセグメンテーションには、text_promptを利用します。

試しにcatで抽出してみました

ann = prompt_process.text_prompt(text='cat')

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

結果は以下になりますが、想定とは違った部分が選択されていました。テキストでの指定がどの程度うまくいくのかは今回の実験では分かりませんでした。

(1, 768, 512)

まとめ

以上、ゼロショットのセグメンテーションを実現するSAMの高速版FastSAMを実際に試した結果を紹介しました。

SAMに比べて多少使いにくい感じはありますが、50倍高速というのは魅力です。実際につかってみて、FastSAMだけでOKということはなく、アプリケーションによってSAMと使い分けが必要かなと感じました。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました