機械学習
記事内に商品プロモーションを含む場合があります

Google ColabでGUI|gradioを使ったウェブアプリの作り方を解説

tadanori

最近、Google Colabで動くAI関連のサンプルに、ウェブアプリによるGUIが付属していることが多くなりました。「あれってどうやってるの」ということで、調べてみたところgradioを使えばできるようです。

ここでは、gradioの使い方について解説します。

gradioとは

gradioは、機械学習モデルを簡単にWebアプリケーションとして実装・公開できるPythonライブラリです。

gradioの特徴は、以下のとおりです。

  • 少ないコードでWebアプリケーションを作成可能
  • さまざまな機械学習モデルをサポート
  • Google Colab上のアプリでも利用可能

ここでは、gradioをGoogle Colab上で動かしてみます

gradioのドキュメントはhttps://www.gradio.app/にあります

PythonでGUIを作る場合、streamlitを使う方法もあります。streamlitについては以下の記事を参考にしてください

PythonでWebアプリ作成|Streamlitの使い方&チートシート
PythonでWebアプリ作成|Streamlitの使い方&チートシート

Google ColabでGUIを表示してみる

インストール

Google Colabにgradioをインストールするには、以下のコードを実行するだけです。

!pip install gradio

入力した文字を、返すアプリを作成

まずは、テキストを入力すると、テキストを返すアプリを作成してみます。

Interfaceの引数として、処理関数fn, 入力inputs, 出力outputsをしています。今回はテキスト→テキストなので以下になります。

また、test()では、文字列の先頭にinput text = をつけて返します。

import gradio as gr

def test(text) :
  return "input text = " + text

app = gr.Interface(fn = test, inputs="text", outputs="text")

app.launch()

これを実行すると以下のような画面が出力されます。赤で囲んだ部分に書かれたリンクをクリックするとGUIが表示されます。

実行画面

表示されたGUIのtextに文字列を入力して、Submitボタンを押すと、outputに文字列が表示されます。

GUI(1)

gradioの動きとしては、以下のようなイメージになります。inputsで指定された入力を関数に渡して処理結果をoutputsに出力するという非常にシンプルなものです。

複数の入力/出力の場合

複数の入力を受け、複数の出力を返す場合は以下のようになります(gradioのサンプルコードより)。inputsoutpusはリストが渡せるので、リストで列挙するだけです。

3つの入力の名前は、fnの関数により決まるようです

import gradio as gr

def greet(name, is_morning, temperature):
    salutation = "Good morning" if is_morning else "Good evening"
    greeting = f"{salutation} {name}. It is {temperature} degrees today"
    celsius = (temperature - 32) * 5 / 9
    return greeting, round(celsius, 2)

demo = gr.Interface(
    fn=greet,
    inputs=["text", "checkbox", gr.Slider(0, 100)],
    outputs=["text", "number"],
)
demo.launch()
GUI(2)

画像を入力する例

画像の入力も同じです(こちらも、gardioのサンプルです)

import numpy as np
import gradio as gr

def sepia(input_img):
    sepia_filter = np.array([
        [0.393, 0.769, 0.189], 
        [0.349, 0.686, 0.168], 
        [0.272, 0.534, 0.131]
    ])
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

demo = gr.Interface(sepia, gr.Image(), "image")
demo.launch()

左の枠に画像をドロップすると、右の画面に結果が出力されます。入力の3つのアイコンは、それぞれ、ファイルアップロード、カメラ、クリップボードからペーストです。カメラを押せばカメラから直接画像を入力可能です。

PCにカメラがついていないといけません

カメラ入力ですが、MacBook Airの内蔵カメラだとエラーになったので何か条件があるかも

GUI(3)

グローバル変数を使った処理

Submitする度に、入力された数値の上位3つを表示するサンプルです。

scoresの結果をtrack_scoreで更新しているのがポイントとなります。このように、グローバル変数に状態を格納しておくことが可能です。

import gradio as gr

scores = []

def track_score(score):
    scores.append(score)
    top_scores = sorted(scores, reverse=True)[:3]
    return top_scores

demo = gr.Interface(
    track_score, 
    gr.Number(label="Score"), 
    gr.JSON(label="Top Scores")
)
demo.launch()
GUI(4)

Playground

garadioにはplaygroundが用意されています。書くGUIパーツの動作などは、ここで確認するとよさそうです。いくつかサンプルがあるので、機能も動かしながら確認できます。

gradio playground

作成したアプリの終了について

Google Colabで実行する場合、GUIを終了するには「ランタイムの接続解除して削除」するしかなさそうです。ここだけちょっと面倒な気がします。

クラス分類をやってみる(PyTorch)

gradioの基本的な使い方を説明したので、PyTorchを使って画像のクラス分類をやるものを作ってみます。

モデルの読み込み

モデルはImageNetで事前学習したresnet18を利用します。

TIMMのモデルやtorchvisionのモデルも試しましたが、この事前学習モデルがそのまま使うのには一番結果がよさそうだったので、これを選びました。

import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

クラスのラベルの読み込み

1000クラスの名前(ラベル)を取得します

import json
import requests

# ImageNetのクラスラベルを取得するURL
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

# URLからラベルを取得
response = requests.get(url)
labels = response.json()
labels[:10]

サンプル画像の読み込み

サンプル画像をダウンロードします。

!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true 

fnを定義

from PIL import Image
from torchvision import transforms

def predict(inp):
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
  return confidences

fnの動作を確認

img = Image.open('cat2.png')
conf = predict(img)
l = []
for e in conf:
  l.append((conf[e], e))
l.sort()
l[::-1][:5]
[(0.1481473594903946, 'Egyptian Mau'),
 (0.09643667936325073, 'tabby cat'),
 (0.06943556666374207, 'tiger cat'),
 (0.02022361569106579, 'Siamese cat'),
 (0.018473317846655846, 'lynx')]

gradioで実行

import gradio as gr

gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=5),
             examples=["cat2.png"]).launch()
GUI(5)クラス分類

まとめ

gradioでは、入力→変換→出力を書くだけなので非常に簡単にGUIを作成することができます。また、Google Colab上でも動作するのは個人的にはかなりポイントが高いです。

ちょっと学習モデルにGUIをつけるのには重宝しそうです。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました