機械学習
記事内に商品プロモーションを含む場合があります

超高速なゼロショットセグメンテーションモデル「FastSAM」の使い方

Aru

MetaのMetaのSegmentation Anything Model(SAM)はゼロショットセグメンテーションで大きな話題を集めましたが、FastSAMはその高速化バージョンです。この記事では、FastSAMの使い方とMetaのSAMとの比較を行った結果を解説します。

MetaのSAMの記事はこちら
Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方
Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方

FAST-SAMとは

FAST Segmentation Anything Model(FastSAM)は、SAMのデータセットのわずか2%で訓練されたCNNベースのゼロショットのセグメンテーションモデルです。

作者によると、SAMの50倍高速ということで、高速性が売りのSAMになります。

オリジナルのSAMがViTをバックボーンに使っていたのに対して、CNNを使うことで高速化を実現しています。ちなみに、CNNの部分はYOLOv8をベースにしているみたいです。

この記事では、FastSAMを実際に動かして、その実力をチェックしてみました。

以下の記事はSAMを動かした時の記事です。この記事では、同様のセグメンテーションを行なって、SAM同士の比較も行いました。

Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方
Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方

FastSAM(Github): https://github.com/CASIA-IVA-Lab/FastSAM

この記事のコードは、以下にあります

Google Colabで動作するコードをこちら(Github)に用意しました

公式のGoogle Colabでの動作テストコードは、コマンドラインツールを呼び出すタイプでしたが、この記事ではPythonで呼び出して使っています。

Pythonで記述するにあたり、ドキュメントが少ないのが目立ちました。この手のライブラリに慣れていないと、使い方を調べるので苦労するかもしれません。とりあえずの使い方は、この記事が参考になると思います

前準備

この記事のコードについて

この記事のコードは、Google Colabで動かすこをと前提としています。一部、Colab独自の手順やパッケージなどがありますので注意してください。

FastSAMのインストール

tensorflow-probabilityをアンインストールする

Google Colabの場合、tensorflow-probabilityのバージョンでエラーが出てインストールが失敗しました。とりあえず、アンインストールして回避しています(2023.12.1)

最初に、tensorflow-probabilityをアンインストールしておきます。

!pip uninstall --yes tensorflow-probability

FastSAMのインストール

FastSAMをインストールするには、以下のコマンドを実行します。

!git clone https://github.com/CASIA-IVA-Lab/FastSAM.git
%cd FastSAM
!pip install -r requirements.txt

Clipのインストール

FastSAMの動作に必要となるCLIPをインストールします。

!pip install git+https://github.com/openai/CLIP.git

以上で、パッケージのインストールは完了です。

途中、ERROR, WARNINGが出るかもしれません

ライブラリのインポート

今回の実験で使うライブラリをインポートします。

from google.colab.patches import cv2_imshowは、Google Colab用のcv2.imshow()の代替関数です。Colabでは、cv2.imshow()の代わりにこちらを使います。

import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

from fastsam import FastSAM, FastSAMPrompt
import torch
import numpy as np

モデルのダウンロード

モデルをダウンロードします。モデルはGoogle Drive上で共有されています。Google Driveからのダウンロードは、「curlやwgetで公開済みGoogle Driveデータをダウンロードする」を参考にしています。

# https://qiita.com/namakemono/items/c963e75e0af3f7eed732
!curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv" > /dev/null
!CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"  ; curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv" -o FastSAM-x.pt

画像の読み込み

画像は、こちらに用意しました。Colabのコードセルでwgetを実行してダウンロードします。

!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true

ダウンロードした画像を読み込みます。

image = cv2.imread('cat2.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(5,5))
plt.imshow(image)
plt.axis('on')
plt.show()

テスト画像は、SAMで使用したものと同じです。同じようにセグメンテーションを行い、両者の違いも確認していきます。

テスト画像

モデルを読み込む

モデルを読み込みます。以下では、prompt_processを利用して処理を行っていきます。

model = FastSAM('./FastSAM-x.pt')
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
everything_results = model(image, device=device, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
prompt_process = FastSAMPrompt(image, everything_results, device=device)

画像全体に自動セグメンテーション

画像全体の自動セグメンテーションは、everything_promptを利用します。

ann = prompt_process.everything_prompt()

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

出力結果は以下になります。出力の感じから、おそらくYOLOv8のセグメンテーションの機能を利用していると思われます。また、ann.shape[20,768,512]となっていることから全部で20個のセグメントが出力されていることも確認できます。

torch.Size([20, 768, 512])
全体のセグメンテーション

YOLOv8のセグメンテーションについては以下の記事を参考にしてください

YOLOv8でセグメンテーション|カスタムデータの学習と推論
YOLOv8でセグメンテーション|カスタムデータの学習と推論

場所を指定してセグメンテーション

SAMと同様に、点、領域を指定してセグメンテーションを行なってみます

単一の点を指定してセグメンテーション

画像の1点を指定して、その点を含むセグメントを検出させた例です。

点を指定したセグメンテーションはpoint_promptを利用します。引数のpointsは、画像中の点の座標、pointlabelは0または1を指定します。1の場合は「含める」指定、0の場合は「除外する」指定になるようです。

prompt_process.plotの引数pointspoint_labelに、点の座標とラベルを渡すと、画像中にプロットしてくれます。

input_point = np.array([[200, 200]])
input_label = np.array([1])

ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

結果は以下になります。シアン色の点が指定した点、赤い枠の中が点から抽出されたセグメントになります。結果を見ると、猫が選択されたようです。

1点を指定してセグメンテーション

複数点を指定してセグメンテーション

点の指定は複数行うこともできます。

以下の例では、2点を指定しています。

input_point = np.array([[200, 200], [400, 650]])
input_label = np.array([1, 1])

ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

猫と本を指定しましたが、ちゃんとセグメントを抽出しているようです。

2点を指定してセグメンテーション

除外する点を指定してセグメンテーション

SAMと同様に除外したい点も指定することができます。

除外したい場所は、pointlabelに0を設定します

input_point = np.array([[200, 200], [400, 650], [200, 600]])
input_label = np.array([1, 1, 0])


ann = prompt_process.point_prompt(points=input_point,  pointlabel=input_label)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                                        points = input_point,
                    point_label = input_label,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

シアン色の点がセグメントとして取り出す部分、マゼンダ色の点が除外する点です。

手の部分だけ除外しようとしましたが、猫全体が除外されてしまいました。

MetaのSAMでは、同じような指定でうまくいきましたが、FastSAMではうまくいきませんでした。このあたりがモデルの差のようです。

高速な分、少し精度が悪いという印象です

除外点を加えてセグメンテーション

矩形領域を指定してセグメンテーション

矩形の領域を指定して、セグメンテーションすることも可能です。

矩形領域でセグメンテーションする場合は、box_promptを使います。

input_boxes = [
    [0, 250, 150, 600],
]

ann = prompt_process.box_prompt(bboxes = input_boxes)

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    bboxes = input_boxes,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

赤い矩形が指定した領域です。この例では、椅子をちゃんとセグメンテーションすることができました。

(1, 768, 512)
矩形領域を指定してセグメンテーション

複数の領域を指定してセグメンテーション

複数の領域を設定して、セグメンテーションを行うことも可能です。

input_boxes = [
    [0, 250, 150, 600],
    [50, 50, 400, 650],
    [80, 550, 512, 768]
]

ann = prompt_process.box_prompt(bboxes = input_boxes)

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',
                    bboxes = input_boxes,)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

椅子、猫、本の3つの領域を指定してセグメンテーションさせましたが、それぞれうまく抜き出せています。なお、領域を3つ指定したので、マスクも3つ出力されています

椅子、猫、本は別のマスクとして出力されています

(3, 768, 512)

テキストを指定してセグメンテーション

サンプルを見ると、テキストでもセグメンテーションできるみたいです。

テキストでのセグメンテーションには、text_promptを利用します。

試しにcatで抽出してみました

ann = prompt_process.text_prompt(text='cat')

print(ann.shape)

prompt_process.plot(annotations=ann,output_path='./result.jpg',)

result = cv2.imread('./result.jpg')
cv2_imshow(result)

結果は以下になりますが、想定とは違った部分が選択されていました。テキストでの指定がどの程度うまくいくのかは今回の実験では分かりませんでした。

(1, 768, 512)

MetaのSAMとの違い(使ってみた感想)

あくまで両者を使ってみた感想になりますが、MetaのSAMと比較すると少し使いにくい印象を持ちました。また、高速性は魅力ですが、除外・追加の指定に対しての動きはMetaのSAMの方が自然な印象を受けました。

高速性と精度を天秤にかけて、どちらを使うか考える必要がありそうです。個人的には、とりあえずFastSAMでやってみみて、厳しそうならMetaのSAMを使うかなと思いました。

ほぼ同じセグメンテーションをMetaのSAMでもやっていますので、そちらもみながら比較してみてください。

MetaのSAMの記事はこちら
Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方
Segmentation Anything(SAM)によるゼロショットセグメンテーション|使い方

まとめ

以上、ゼロショットのセグメンテーションを実現するSAMの高速版FastSAMを実際に試した結果を紹介しました。

SAMに比べて多少使いにくい感じはありますが、50倍高速というのは魅力です。実際につかってみて、FastSAMだけでOKということはなく、アプリケーションによってSAMと使い分けが必要かなと感じました。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました