Google ColabでGUI|gradioを利用したウェブアプリの作り方を解説
Google Colabを利用したAI関連のプロジェクトでは、ウェブアプリによるGUI(グラフィカルユーザーインターフェース)が付属することが増えています。この多くが、Gradioというライブラリを利用しています。この記事では、自作したGoogle Colab上のアプリにGUIを加えるための方法として、Gradioの使い方について解説します。
gradioとは
gradioは、機械学習モデルを簡単にWebアプリケーションとして実装・公開できるPythonライブラリです。
gradioの特徴は、以下のとおりです。
- 少ないコードでWebアプリケーションを作成可能
- さまざまな機械学習モデルをサポート
- Google Colab上のアプリでも利用可能
ここでは、gradioをGoogle Colab上で動かしてみます
gradioのドキュメントはhttps://www.gradio.app/にあります
PythonでGUIを作る場合、streamlitを使う方法もあります。streamlitについては以下の記事を参考にしてください
Google ColabでGUIを表示してみる
インストール
Google Colabにgradioをインストールするには、セルに以下のコードを挿入します。
!pip install gradio
入力した文字を、返すアプリを作成
まずは、「テキストを入力すると、入力したテキストをoutputに出力する」アプリを作成してみます。
Interface()
の引数には、処理関数fn
, 入力inputs
, 出力outputs
を指定します。
今回は入力したテキスト”text”を、”text”で出力するので、以下のようなコードになります。
また、処理関数test()
では、入力されたテキストに"input text = "
を加えて戻り値として返しています。
import gradio as gr
def test(text) :
return "input text = " + text
app = gr.Interface(fn = test, inputs="text", outputs="text")
app.launch()
たった、これだけの記述でGUIアプリは完成です。
とりあえず、実行してみてください。実行すると、以下のような画面がColabに出力されると思います。このうち、赤枠で囲んだ部分にある青文字のURLをクリックするとGUIが表示されます。
実際に動作確認してみましょう。text側に文字列を入力してSubmit
ボタンを押してください。するとoutput側に文字列が表示されたと思います。
ここでgradioの動作を整理します。gradioの動きを簡単な図で示すと下図になります。
inputで入力されたデータが関数fn
に渡され、処理結果がoutputに出力されるという流れです。gradioの動作はシンプルでかなりわかりやすいです。
次は、ちょっとだけ複雑なパターンを見てみます。
複数の入力/出力の場合
入力が複数、出力が複数ある例です。
複数の入力を受け、複数の出力を返す場合は以下のようになります。inputs
、outpus
はリストが渡せるので、リストで列挙するだけです。
3つの入力の名前は、fn
の関数により決まるようです
import gradio as gr
def greet(name, is_morning, temperature):
salutation = "Good morning" if is_morning else "Good evening"
greeting = f"{salutation} {name}. It is {temperature} degrees today"
celsius = (temperature - 32) * 5 / 9
return greeting, round(celsius, 2)
demo = gr.Interface(
fn=greet,
inputs=["text", "checkbox", gr.Slider(0, 100)],
outputs=["text", "number"],
)
demo.launch()
画像を入力する例
画像を入力する例です。画像の入力は以下のようになります。
import numpy as np
import gradio as gr
def sepia(input_img):
sepia_filter = np.array([
[0.393, 0.769, 0.189],
[0.349, 0.686, 0.168],
[0.272, 0.534, 0.131]
])
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
demo = gr.Interface(sepia, gr.Image(), "image")
demo.launch()
左の枠に画像をドロップすると、右の画面に結果が出力されます。入力の3つのアイコンは、それぞれ、ファイルアップロード、カメラ、クリップボードからペーストです。カメラを押せばカメラから直接画像を入力可能です。
PCにカメラがついていないといけません
カメラ入力ですが、MacBook Airの内蔵カメラだとエラーになったので何か条件があるかも
グローバル変数を使った処理
Submit
する度に、入力された数値の上位3つを表示するサンプルです。
scores
の結果をtrack_score
で更新しているのがポイントとなります。このように、グローバル変数に状態を格納しておくことが可能です。
import gradio as gr
scores = []
def track_score(score):
scores.append(score)
top_scores = sorted(scores, reverse=True)[:3]
return top_scores
demo = gr.Interface(
track_score,
gr.Number(label="Score"),
gr.JSON(label="Top Scores")
)
demo.launch()
garadioにはplaygroundが用意されています。書くGUIパーツの動作などは、ここで確認するとよさそうです。いくつかサンプルがあるので、機能も動かしながら確認できます。
作成したアプリの終了について
Google Colabで実行する場合、GUIを終了するには「ランタイムの接続解除して削除」するしかなさそうです。ここだけちょっと面倒な気がします。
クラス分類をやってみる(PyTorch)
以上、gradioの基本的な使い方を説明しました。わかりやすいインタフェースで使いやすい感じだったのではないでしょうか。
ここでは、実際のアプリの例として、PyTorchを使って入力された画像の分類を行うアプリを行うGUIアプリを作ってみます。
モデルの読み込み
モデルはImageNetで事前学習したresnet18を利用します。
TIMMのモデルやtorchvisionのモデルも試しましたが、この事前学習モデルがそのまま使うのには一番結果がよさそうだったので、これを選びました。
import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
クラスのラベルの読み込み
1000クラスの名前(ラベル)を取得します
import json
import requests
# ImageNetのクラスラベルを取得するURL
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# URLからラベルを取得
response = requests.get(url)
labels = response.json()
labels[:10]
サンプル画像の読み込み
サンプル画像をダウンロードします。
!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true
fn
を定義
from PIL import Image
from torchvision import transforms
def predict(inp):
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
return confidences
fn
の動作を確認
img = Image.open('cat2.png')
conf = predict(img)
l = []
for e in conf:
l.append((conf[e], e))
l.sort()
l[::-1][:5]
[(0.1481473594903946, 'Egyptian Mau'),
(0.09643667936325073, 'tabby cat'),
(0.06943556666374207, 'tiger cat'),
(0.02022361569106579, 'Siamese cat'),
(0.018473317846655846, 'lynx')]
gradioで実行
import gradio as gr
gr.Interface(fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=5),
examples=["cat2.png"]).launch()
まとめ
gradioでは、入力→変換→出力を書くだけなので非常に簡単にGUIを作成することができます。また、Google Colab上でも動作するのは個人的にはかなりポイントが高いです。
ちょっと学習モデルにGUIをつけるのには重宝しそうです。